Students Info¶

Hello there!

We are Nadav and Amir - a pair of young Computer Science students at The Open University of Israel. We're both in the final semester of our studies, after which we'll begin our military service. We're both passionate about Data Science and Machine Learning. We went through several courses in this field together - making it a natural choice for us to team up, and so we did.


Introduction¶

Looking for a Problem to Solve¶

The problem we are going to tackle in this project is quite an interesting one, that we managed to get to very unexpectedly.

After teaming up, we brainstormed ideas for weeks, trying to find the most enlightening problem we could solve. We wanted it to be meaningful: to use state-of-the-art data science technologies to solve a problem that could, in theory, help the world in one way or another. As a pair of curious young students, we also wished for it to be a difficult challenge that will allow us to express our creativity searching for a successful solution. Looking for such an enlightening problem was hard - until suddenly, out of the blue, it hit us.

One day my (Nadav) brother reached to me, telling me that the older brother of a friend of his needs the help of a data scientist. He connected us, and thus we met Doctor Paz Kelmer from Semmelweis University in Budapest. He was just finishing his MD degree with outstanding achievements, and was working on an innovative research in the field of Neuroscience. He presented to us an idea that could speed up the research in this field. It was a perfect problem for us, and we decided to dedicate our final data science project to solve it.

Some Neuroscience Background¶

Let's start by establishing some basic knowledge we'll use to understand the problem at hand.

The human brain is the most important organ in our bodies. It is so complicated that scientists don't even have (yet) a thorough understanding of the way it operates. It controls the entire body via the nervous system - through which different parts of the body communicate.

Its main building blocks are special cells, called neurons. These cells posess electrical charges (voltage): there is an electric potential difference between their inside and their outside, due to the presence of different ions (mainly sodium and potassium, but also chlorine and calcium). Their capacity is around a few tens of millivolts. Neurons are connected by small "wires" called axons. Once the voltage in a neuron exceeds some threshold, it starts to fire electric pulses through the axons (reminds you of something?). We'll not delve deeper into the other components of the neuron, since they are not relevant to gain an understanding of the problem.

Figure 1: The neuron.

There are principal neurons and interneurons. Principal neurons and their networks underlie local information processing/storage and represent the major sources of output from any brain region, whereas interneurons, by definition, have local axons that govern ensemble activity.

A nucleus is a cluster of neurons in the central nervous system. The neurons in one nucleus usually have similar connections and functions. At a certain region of the human brain, there is a group of nuclei called the basal ganglia. These are responsible for motor control, motor learning, executive functions and behaviours, and emotions. The largest structure in it is called the striatum, and it's a critical component of the motor and reward systems. It coordinates multiple aspects of cognition, including both motor and action planning, decision-making, motivation, reinforcement, and reward perception. It contains a structure called the putamen, whose primary function is to regulate movements at various stages (e.g. preparation and execution) and influence various types of learning.

Let's move our discussion to a certain type of interneurons, called calretinin interneurons (CR+ neurons). These are the most abundant interneurons in the human striatum and there is an evolutionary trend where primates have 10X more of them than rodents. However, their exact function is still unknown and there is still much to be discovered. They have 2 properties we are interested in:

  • They produce GABA - an amino acid that locally regulates the electric charges of neurons: this molecule connects to following neurons, injects chlorine into them, and let's potassium out of them. This causes the electric potential difference between their inside and their outside to shrink. In a way, they "turn off" following neuorns. That means that they play an interesting role in the "logic" of the striatum.

  • They produce calretinin - a protein that buffers calcium. The protein's exact role is actually not relevant for us, we only care about the fact that CR+ neurons are rich of it - it can be found all over their membranes (their outer shells).

Doctor Kelmer and his team identified 3 groups of CR+ neurons: small cells with 6-15 um diameter, medium cells with 15-25 um diameter, and large cells with 25-60um diameter.

Figure 2: CR+ neurons are divided into 3 groups, based on the length of their diameter.

Schizophrenia and Doctor Kelmer's Research¶

Let's talk a bit about schizophrenia. Schizophrenia is a serious chronic neuropsychiatric disorder ranked top 10 of global burden of disease by the World Health Organisation, with a global prevalence of about 1%. The symptoms include positive productive symptoms such as hallucinations, negative deficit symptoms such as apathy, and cognitive symptoms such as poor learning skills. There are no objective biomarkers available for schizophrenia. Diagnosis depends upon the subjective clinicians' opinion based on questioners and criteria provided by the Diagnostic and Statistical Manual of Mental Disorders.

The functions of the basal ganglia area mentioned above are tightly connected to these symptoms. It is also known that diseases involving this area of the brain are associated with psychosis such as Huntington. Doctor Kelmer and his team tried to find a correlation between schizophrenia and the structure of this area of the brain. Since schizophrenic patients have bizarre movements, they focused on the putamen (depicted below) and the distribution of CR+ cells in it.

Figure 3: The putamen - sample ID12718 (scanned using a whole-slide scanner after IHC-staining was applied, as described below).

But why choose CR+ cells? As mentioned above, they play an interesting role in the "logic" of the striatum, but more than that - they can be easily located. Remember the calretinin protein they produce, that can be found all over their membranes? It responds to some sort of antibody, which stains them. By applying it to samples of the striatum, we can cause those CR+ cells to show up! This process is an example of what's known as immunohistochemistry (IHC) staining - a method that uses antibodies to check for certain antigens (markers) in tissue samples.

They applied IHC staining to putamen brain sammples and scanned them using whole-slide scanning. Then, they manually marked the CR+ cells that showed up in them. The team found some strong correlations between the distribution of CR+ cells in the putamen to the presence of schizophrenia. Their research was very successful and was recognized as the best one at a conference they presented it in. However, it was not yet published.

Let's take a moment to think about the process that the team went through: they had to sit for hours upon hours to mark down tens of thousands of CR+ neurons that showed up in the scans. This tedious manual marking of cells is the method that research teams have been using for years (including ones that explored similar concepts), which - among other things - made progress in this field very slow.

This is where we come into play.

The Problem and the Data¶

Doctor Paz Kelmer thought of the idea of an automated system for marking cells in IHC-stained whole slide images (WSIs). He needed someone with knowledge in computer vision and reached out to me (Nadav) through our brothers (who are friends from work). I came with this idea to Amir, and we were both excited about dedicating our project to this meaningful task.

The data we were given by Paz to work with was a couple of the WSIs that he and his research team manually scanned, marking any CR+ neurons they could find. Each such cell is marked along its diameter by a straight line which we'll call a cell mark, stretching between the furthest edges of the cell.

Figure 4: IHC-stained CR+ cells were manualy marked by Doctor Kelmer's research team.

The WSIs are divided into image areas, each has some interest for the neuroscientists. It is important to note that the team did not mark cells outside these areas.

Figure 5: The putamen - sample ID12718, divided into the relevant image areas (marked by the research team).

Each WSI in the data we have consists of two files:

  • A TIFF image file, containing the (huge) whole-slide-image itself.

  • An XML annotations file, containing information about the layout of the image (the image areas, the cell marks, and additional information such as the scale of the image in real life).

Paz wanted us to use those samples to create a system that could automatically scan areas of the image, mark any CR+ neurons in them, and obtain the distribution of the cells' diameter - all without human interference. He emphasized a few specifications that the system should follow:

  1. Cells should be marked on their longest diameter.
  2. The desired cells should be round objects, but not perfect circles.
  3. The desired cells should be brown (or maybe brown & blueish, but a blue object should not be marked if it has no brown in it).
  4. Only mark cells that are at least 4um in diameter, but no more than 60um.
  5. False positives are less wanted - it is better to miss a cell than to mark something that's not a cell. The reason for that is to make human correction easier (it's easier to mark a missed cell than to delete a false mark).

He also noted that it is very likely that they missed some CR+ neurons while manually placing markers on the images.


Preparations¶

In [ ]:
!pip install pytiff
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytiff
  Downloading pytiff-0.8.1.tar.gz (1.3 MB)
     |████████████████████████████████| 1.3 MB 28.5 MB/s 
Building wheels for collected packages: pytiff
  Building wheel for pytiff (setup.py) ... done
  Created wheel for pytiff: filename=pytiff-0.8.1-cp37-cp37m-linux_x86_64.whl size=839345 sha256=83a66299b94fe95b961bbbc48f5815a3701f2d739c7a7cf9e5ad64d324b50a32
  Stored in directory: /root/.cache/pip/wheels/10/f5/61/defc30c7ec905f5b7cce3452360d3c5f798d339d78d2f020ca
Successfully built pytiff
Installing collected packages: pytiff
Successfully installed pytiff-0.8.1
In [ ]:
# import necessary packages

import random
import math
from tqdm import tqdm
import time
import os
from xml.dom import minidom
from xml.etree import ElementTree

from pytiff import Tiff
from shapely.geometry import Point, Polygon, MultiPolygon, LineString

%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import animation, image as img
from matplotlib import colors
from IPython.display import HTML
from PIL import Image
from google.colab.patches import cv2_imshow

import numpy as np
import cv2
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import transforms
from torchvision.transforms.functional import gaussian_blur, convert_image_dtype

import pandas as pd
import seaborn as sns
import scipy
from scipy.stats import multivariate_normal
from scipy import ndimage

import skimage
from skimage import measure
from skimage.morphology import erosion, opening
from skimage.draw import disk
In [ ]:
# mount drive and move to the project directory
from google.colab import drive, files
drive.mount('/content/drive')
%cd /content/drive/My\ Drive/DSProject/
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/DSProject
In [ ]:
# set manual seed for PyTorch (for reproducible results)
torch.manual_seed(0)
Out[ ]:
<torch._C.Generator at 0x7f9754045430>
In [ ]:
# Get available device for PyTorch, preferably GPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Using device:', device)
Using device: cuda:0

Quick note: we decided to write some "thick" utility functions (regarding shallow matters such as formatting print output while discussing less important issues) in a separate python file named DSProjectUtils.py (found in the project's directory). We did this to make our discussions look cleaner. Of course, we don't use those functions while writing more central parts of the project.

In [ ]:
# import our utility package
import DSProjectUtils

Preparing the Data¶

Each putamen sample was scanned into a whole-slide-image represented by a huge TIFF image file - stretching tens of thousands of pixels, and taking a few giga-bytes of memory space. Those slides were loaded into a program called Aperio ImageScope. This program allows the neuroscientists to zoom into those whole-slide-images and annotate different areas in them (it also offers some other types of functionality, but these are not relevant for us). It then stores the data regarding those annotations (alongside other information) into an XML annotations file of the same name.

Each WSI we are given with cosists of those 2 files - a TIFF image file and an XML annotations file (both with the same name). It is our job to make sense of it all, and extract the relevant information. We care about 4 things:

  • The image data - in the TIFF image file.
  • The location of the relevant image areas - in the XML annotations file.
  • The cell-marks - in the XML annotations file.
  • The scale of the image (in real life) - in the XML annotations file.

So, let's jump right into it! For exploration's sake, we'll examine a specific sample labeled ID12718. The TIFF image file associated with it can be found in samples/ID12718.tiff, and the XML annotations file associated with it can be found in samples/ID12718.xml.

In [ ]:
tiff_path = 'samples/ID12718.tiff'
xml_path = 'samples/ID12718.xml'

Loading the Image Data¶

We will first try to figure out how to handle the image data contained in the associated TIFF file. Let's start by getting some basic properties of the file:

In [ ]:
size = os.path.getsize(tiff_path) / 2**30 # in gibibytes
height, width = Tiff(tiff_path).shape[:2]

print("Putamen sample image:")
print(f'\tSize: {size : .2f} GiB')
print(f'\tResolution: {height}px X {width}px')
Putamen sample image:
	Size:  4.18 GiB
	Resolution: 41203px X 49751px

This image is HUGE - we most certainly don't want to load it all at once to the main memory! We need a mechanism to load only small patches of it when required.

We surfed the web looking for a way to accomplish that and stumbled upon a small open-source library called Pytiff. We installed it using pip in the Preparations section above. The library associates what's called a handle object with each TIFF image via the 'Tiff' function which we already imported. This object can then be subscripted in the most "pythonic" way possible to load patches of the image into NumPy arrays in main memory:

In [ ]:
# create handle and load patch
img_handle = Tiff(tiff_path)
example_patch = img_handle[5000:5250, 5000:5250]
print(f'Type of TIFF handle object: {type(img_handle)}')
print(f'Type of example patch: {type(example_patch)}')

# show it
plt.imshow(example_patch)
plt.title('Example Patch')
plt.axis('off')
None
Type of TIFF handle object: <class 'pytiff._pytiff.Tiff'>
Type of example patch: <class 'numpy.ndarray'>

Excellent! That covers the problem of loading the image data.

Preparing the Annotations¶

Understanding the XML Annotations File¶

Next, we need to understand what are those mysterious "image areas", and extract the cell marks found in them. For this task, we'll need to understand the XML annotations file. Let's start by loading it, and have a quick first look at its contents:

In [ ]:
# load and display XML ElementTree
root = ElementTree.parse(xml_path).getroot()
print(DSProjectUtils.format_etree(root, 2))
<Annotations MicronsPerPixel="0.501900">
	<Annotation Id="20526" Name="" ReadOnly="0" NameReadOnly="0" LineColorReadOnly="0" Incremental="0" Type="4" LineColor="65280" Visible="1" Selected="0" MarkupImagePath="" MacroName="">
		<Attributes>
			<Attribute Name="Description" Id="701418" Value="" />
		</Attributes>
		<Regions>
			<RegionAttributeHeaders>
				<AttributeHeader Id="271155" Name="Region" ColumnWidth="-1" />
				<AttributeHeader Id="271156" Name="Length" ColumnWidth="-1" />
				[... 3 more AttributeHeader nodes]
			</RegionAttributeHeaders>
			<Region Id="222293" Type="0" Zoom="0.042849" Selected="0" ImageLocation="" ImageFocus="0" Length="74736.1" Area="125998015.9" LengthMicrons="37510.0" AreaMicrons="31739357.3" Text="" NegativeROA="0" InputRegionId="0" Analyze="1" DisplayId="1">
				<Attributes />
				<Vertices>
					<Vertex X="7374.437049" Y="36663.447974" />
					<Vertex X="7351.099282" Y="36663.447974" />
					[... 1748 more Vertex nodes]
				</Vertices>
			</Region>
			<Region Id="267376" Type="4" Zoom="1" Selected="0" ImageLocation="" ImageFocus="0" Length="23.4" Area="0.0" LengthMicrons="11.8" AreaMicrons="0.0" Text="" NegativeROA="0" InputRegionId="0" Analyze="0" DisplayId="2">
				<Attributes />
				<Vertices>
					<Vertex X="10444" Y="25229" />
					<Vertex X="10429" Y="25247" />
				</Vertices>
			</Region>
			[... 877 more Region nodes]
		</Regions>
	</Annotation>
	<Annotation Id="20527" Name="" ReadOnly="0" NameReadOnly="0" LineColorReadOnly="0" Incremental="0" Type="4" LineColor="65535" Visible="1" Selected="0" MarkupImagePath="" MacroName="">
		<Attributes>
			<Attribute Name="Description" Id="701419" Value="" />
		</Attributes>
		<Regions>
			<RegionAttributeHeaders>
				<AttributeHeader Id="271160" Name="Region" ColumnWidth="-1" />
				<AttributeHeader Id="271161" Name="Length" ColumnWidth="-1" />
				[... 3 more AttributeHeader nodes]
			</RegionAttributeHeaders>
			<Region Id="222294" Type="0" Zoom="0.042849" Selected="0" ImageLocation="" ImageFocus="0" Length="27148.5" Area="40982837.6" LengthMicrons="13625.8" AreaMicrons="10323725.5" Text="" NegativeROA="0" InputRegionId="0" Analyze="1" DisplayId="1">
				<Attributes />
				<Vertices>
					<Vertex X="10548.519236" Y="36033.264128" />
					<Vertex X="10571.857003" Y="36033.264128" />
					[... 490 more Vertex nodes]
				</Vertices>
			</Region>
			<Region Id="269338" Type="4" Zoom="1" Selected="0" ImageLocation="" ImageFocus="0" Length="15.8" Area="0.0" LengthMicrons="7.9" AreaMicrons="0.0" Text="" NegativeROA="0" InputRegionId="0" Analyze="0" DisplayId="2">
				<Attributes />
				<Vertices>
					<Vertex X="10681" Y="35848" />
					<Vertex X="10676" Y="35863" />
				</Vertices>
			</Region>
			[... 266 more Region nodes]
		</Regions>
	</Annotation>
	[... 8 more Annotation nodes]
</Annotations>

Whoa, that's a lot of information! Let's try to make sense of it.

First, we can see that the root of the XML is an element called Annotations, which has an attribute: MicronsPerPixel. This has to be the scale of the image in real life (micron = micro meter = $10^{-6}$ meter):

In [ ]:
scale = float(root.attrib['MicronsPerPixel'])
print(f'Image scale: {scale}um / px.')
Image scale: 0.5019um / px.

The second thing that we can observe is that the file is separated into sections, each represented by an Annotation node. Each of these has many types of information in it, which mostly seem to be ID information that is probably used by the ImageScope program. We don't care about those - we only care about the location and the meaning of the annotations themselves, so we'll ignore most of it.

One thing that strikes the eye is the fact that each Annotation node has a Regions child node, filled with some type of nodes called Region. What makes them interesting is the fact that each of them has a Vertices child-node, filled with a few hundreds / only 2 Vertex nodes - each having an X attribute and a Y attribute.

That looks promising - we finally observe coordinates in the annotations XML file. But what are these Region nodes, and why do some of them contain hundreds of vertices, and some contain only 2? Well, we know that each cell mark is an individual annotation element, which looks like a straight line stretching between 2 points in the image - so these 2-vertex Region nodes may be cellmarks. Let's check this hypothesis by observing the coordinates of such an element in ImageScope:

In [ ]:
possible_cellmark = root[0][1][2]
print(DSProjectUtils.format_etree(possible_cellmark))
<Region Id="267376" Type="4" Zoom="1" Selected="0" ImageLocation="" ImageFocus="0" Length="23.4" Area="0.0" LengthMicrons="11.8" AreaMicrons="0.0" Text="" NegativeROA="0" InputRegionId="0" Analyze="0" DisplayId="2">
	<Attributes />
	<Vertices>
		<Vertex X="10444" Y="25229" />
		<Vertex X="10429" Y="25247" />
	</Vertices>
</Region>
Figure 6: Screenshot from Aperio ImageScope, opened on the sample labeled ID12718. The coordinates of the shown location in the image are written at the bottom.

That is indeed a cellmark! We checked a few other examples of Region nodes with 2 vertices, and all of them seemed to be cell marks.

Next, we need to figure out the image areas. Let's delete all of the cellmarks from the XML annotations file (the Region nodes with 2 vertices), and take a look at what's left in ImageScope:

Figure 7: Screenshot from Aperio ImageScope, opened on the sample labeled ID12718 after deleting the cell-marks from the associated XML annotations file.

We can see that we're left with annotations marking the edges of the image areas (and some green dots we'll talk about later). But how are these represented? Well, what's left in the XML are Region nodes, containing hundreds of vertices and stretching tens of thousands of pixels (see the Length attribute), for instance:

In [ ]:
lengthy_region = root[2][1][1]
print(DSProjectUtils.format_etree(lengthy_region, 5))
<Region Id="222295" Type="0" Zoom="0.042849" Selected="0" ImageLocation="" ImageFocus="0" Length="46299.2" Area="56640747.8" LengthMicrons="23237.6" AreaMicrons="14268009.8" Text="" NegativeROA="0" InputRegionId="0" Analyze="1" DisplayId="1">
	<Attributes />
	<Vertices>
		<Vertex X="19043.466502" Y="30828.942039" />
		<Vertex X="19043.466502" Y="30805.604272" />
		<Vertex X="19043.466502" Y="30735.590971" />
		<Vertex X="19043.466502" Y="30665.577669" />
		<Vertex X="19043.466502" Y="30642.239902" />
		[... 788 more Vertex nodes]
	</Vertices>
</Region>

These must have something to do with the image areas - let's zoom in with ImageScope on the one we just looked at:

Figure 8: Screenshot from Aperio ImageScope, after zooming in on the same sample shown in Figure 7. The coordinates of the location circled in purple are marked at the bottom of the image.

Huh! That first Vertex is the edge of the boundary of the bottom-left red image area shown in Figure 7 (at its bottom-right, where it intersects with yellow and cyan image areas). It can also be seen that such image areas are represented by long polylines marking their boundaries. So - these huge Region nodes must be the boundaries of the image areas!

However - why are there multiple of these in some of the Annotation nodes? For example - why does the same Annotation node containing the area boundary we've just seen, contains the following "huge" Region node as well:

In [ ]:
another_lengthy_region = root[2][1][2]
print(DSProjectUtils.format_etree(another_lengthy_region, 5))
<Region Id="222296" Type="0" Zoom="0.210000" Selected="0" ImageLocation="" ImageFocus="0" Length="4378.5" Area="480691.6" LengthMicrons="2197.6" AreaMicrons="121088.0" Text="" NegativeROA="1" InputRegionId="0" Analyze="1" DisplayId="2">
	<Attributes />
	<Vertices>
		<Vertex X="16847.761974" Y="29014.476231" />
		<Vertex X="16843.000069" Y="29014.476231" />
		<Vertex X="16838.238164" Y="29014.476231" />
		<Vertex X="16833.476259" Y="29019.238136" />
		<Vertex X="16828.714354" Y="29019.238136" />
		[... 739 more Vertex nodes]
	</Vertices>
</Region>

Let's zoom into it with ImageScope:

Figure 9: Screenshot from Aperio ImageScope, after zooming in on the same sample shown in Figure 7. The coordinates of the location circled in purple are marked at the bottom of the image.

We couldn't figure out what these are, so we asked our friend who created these annotations - Doctor Paz Kelmer. He told us that these are regions of the image that are NOT part of the corresponding image area (in this case - the bottom-left red image area). To differ between the 2 types of polyline boundaries, we'll call those that mark what's inside an image area positive boundaries, and those that mark what's discluded from an image area negative boundaries.

But how can we differ between them in the XML annotations file? After checking a few more of these Region nodes stretching thousands of pixels, we noticed a repeating difference in the NegativeROA attribute of the nodes representing positive and negative boundaries. The positive ones have this atribute set to "0", and the negative ones have this attribute set to "1".

One last thing we need to pay attention to before we can start writing code is the appearance of some weird green dots in figure 7. Let's zoom in on one of them:

Figure 10: Screenshot from Aperio ImageScope, after zooming in on one of the "unknown green-dots" in Figure 7.

Well, it seems like the research team did more than just marking cells with straight segments - they even surrounded some with polylines! At first glance, this seems like a better source of data than the simple cellmarks. However, we only have a few tens of these, and they're all clustered at the bottom-left of the image (which means that they're a poor representation of the data since they represent only a small part of it). So on second thought - they're not that good at all. Therefore, we decided not to use these in our algorithms.

Therefore we need to somehow identify these cell-perimeter markings and discard them. We couldn't find any consistent difference between those annotations and the positive boundaries, except for the fact that their Length attribute is way smaller (since the perimeter of a neuron is way smaller than the perimeter of a huge image area containing many many neurons). So, we decided to differ between them using some threshold for the Length attribute. After examining their differences we decided on a threshold of 1000pxs.

Before we continue, let us list a few more observations we had while examining elements of the XML annotations file:

  1. If it was not already clear, each Annotation node represents a DIFFERENT image area. The cellmarks and area boundaries in one of them are all parts of the same image area.

  2. There is ONLY ONE positive boundary associated with each image area, so there is only one Region node representing a positive boundary in each Annotation node.

  3. We can differentiate between area boundaries and cellmarks by the Type attribute of their Region: area boundaries have Type="0" and cellmarks have Type="4". However, cell-perimeter markings also have Type="0".

  4. Negative boundaries are the only Region nodes with attribute NegativeROA="1".

Processing the XML Annotations File¶

Let's sum up our knowledge about the XML annotations file:

  • It is made up of Annotation nodes that represent image areas.
  • Each Annotation has a single Regions child-node, which has multiple Region child-nodes that represent some kind of annotations in the corresponding image area.
  • Each Region node has multiple Vertex descendant nodes that represent the vertices that make up the corresponding annotation.
  • Each Vertex has two attributes X, Y that represent its coordinates inside the image.

There are 4 types of annotations (Region nodes):

  1. Positive boundaries: polylines surrounding each image area. Can be identified by having attributes Type="0", NegativeROA="0", and Length >= 1000.
  2. Negative boundaries: polylines surrounding areas that are discluded from an image area ("holes"). Can be identified by having attribute NegativeROA="1".
  3. Cellmarks: line segments marking the diameter of CR+ cells. Can be identified by having attributes Type=4.
  4. Cell perimeter markings: polylines surrounding some of the cells. These are all the annotations that are left, and as mentioned above we decided to discard those.

Let's write some code to process XML annotation files. We'll write a function that extracts the relevant information in it into some easy-to-use python objects. We'll use the standard shapely package to represent the elements of the image - the boundaries will be represented by Polygon objects (negative boundaries will be "holes" in them), and cellmarks will be represented by LineString objects.

In [ ]:
'''
A utility function for extracting vertices of a Region node in an XML file associated
  with an IHC-staining image.
Input - an xml.dom.minidom.Element associated with the Region node.
Output - list of pairs of floats, representing the coordinates of the vertices of the
  region.
'''
def extract_verts(region):
    # extarct Vertex child nodes
    verts = region.getElementsByTagName('Vertex')
    # extract their coordinates
    verts_cords = [(float(vert.attributes['X'].value),
                      float(vert.attributes['Y'].value))
                   for vert in verts]
    return verts_cords


'''
Extracts the relevant information from an XML annotations file associated with a
  whole-slide-image, produced by Aperio ImageScope.

Input:
  > 'src_xml' - path for the XML annotations file containing.
  > 'bound_min_len' - the minimum Length attribute of a region in the XML (with
    Type="0" and NegativeROA="0") to be considered as a positive boundary of an
    area in the image. Default: 1000.

Returns a pair - (img_areas, microns_per_pixel), where 'microns_per_pixel' is a float
  that represents the scale of the image (um/px), and 'img_areas' is a list of python
  dictionaries representing the different areas in the image.
Each dictionary in 'img_areas' contains 2 entries:
  > 'boundary' - a shapely Polygon representing the boundary of the area (its holes
    represent the negative boundaries).
  > 'cellmarks' - list of cell marks, each represented by a shapely LineString.
'''
def process_wsi_annots(src_xml, bound_min_len=1000):

    # extract the relevant information from the xml
    with minidom.parse(src_xml) as dom:

        img_areas = [] # to store the area dictionaries

        # check every image area
        for annotation in dom.getElementsByTagName('Annotation'):
            regions = annotation.getElementsByTagName('Regions')[0]

            positive_bound = None # to store list of vertices of psoitive boundary
            negative_bounds = [] # list of lists of vertices of negative boundaries
            cellmarks = [] # list of pairs of vertices of cellmarks

            # iterate through the 'Region' child nodes
            for child_node in regions.childNodes:
                if isinstance(child_node, minidom.Element) and \
                  child_node.tagName == 'Region':
                    # figure-out the region's type according to its attributes
                    region_type = int(child_node.attributes['Type'].value)
                    is_negative = (child_node.attributes['NegativeROA'].value
                                   == "1")
                    region_len = float(child_node.attributes['Length'].value)

                    if region_type == 0 and not is_negative and \
                      region_len >= bound_min_len:
                         # it's a positive area boundary (there is only one in the area)
                         vertices = extract_verts(child_node)
                         positive_bound = vertices

                    elif is_negative:
                        # it's a negative area boundary
                        vertices = extract_verts(child_node)
                        negative_bounds.append(vertices)

                    elif region_type == 4:
                        # it's a cellmark
                        cellmarks.append(LineString(extract_verts(child_node)))

            # if the current area is not empty - add it
            if positive_bound != None:
                img_area = {'boundary' : Polygon(positive_bound,
                                                holes=negative_bounds).buffer(0),
                            'cellmarks' : cellmarks}
                img_areas.append(img_area)

        # next - collect the microns-per-pixel attribute of the WSI
        annotations = dom.getElementsByTagName('Annotations')[0]
        microns_per_pixel = float(annotations.attributes['MicronsPerPixel'].value)

    return img_areas, microns_per_pixel

Let's use it to process our putamen sample:

In [ ]:
img_areas, microns_per_pixel = process_wsi_annots(xml_path)
print(f'Image scale: {scale}um / px.')
print(f'The image contains {len(img_areas)} areas.')
Image scale: 0.5019um / px.
The image contains 9 areas.

Let's verify the area boundaries by displaying a couple of them:

In [ ]:
for img_area in img_areas[:2]:
    display(img_area['boundary'])

This looks just about right!

Tiling¶

Our task is to automatically scan whole-slide-images, detect some sort of objects in certain areas of them, and then mark those objects in some specific manner. However, as we've already seen, each image is way too large to load and process in main memory. Moreover, large parts of the images are outside the boundaries that we care about (the image areas).

Let's perform a quick analysis of this issue:

In [ ]:
# extract information about our image
size = os.path.getsize(tiff_path) / 2**30
height, width = img_handle.shape[:2]
total_area = height*width

print("Putamen sample image:")
print(f'\tSize: {size :.2f} GiB')
print(f'\tResolution: {height}px X {width}px')
print(f'\tArea: {total_area :.2e} px^2')

# use the img_areas list extracted previously to evaluate the relevant areas
relevant_area = sum([img_area['boundary'].area for img_area in img_areas])
print(f'\tTotal area of relevant regions: {relevant_area :.2e} px^2')
print(f'\tPart of relevant regions: {relevant_area / total_area :.2%}')
Putamen sample image:
	Size: 4.18 GiB
	Resolution: 41203px X 49751px
	Area: 2.05e+09 px^2
	Total area of relevant regions: 7.48e+08 px^2
	Part of relevant regions: 36.50%

So if we process the entire image, almost 2 thirds of our work will go for nothing!

The method we came up with to overcome both of these problems is called tiling. The idea is to divide the image into small tiles which will be loaded separately (thus resolving the memory issue), and discard tiles that are not part of any image area (thus resolving the issue of the non-relevant regions):

Figure 11: Tiling - dividing whole-slide-images into small tiles, and extracting those that are part of some relevant area. This is putamen sample ID12718, divided into tiles of size 1024px by 1024px. Tiles that are part of some relevant image area are colored.

However, how can we determine that a tile is part of some image area? Well, we had a few ideas:

  • Our first interpretation was that a tile should be fully contained in the area to which it belongs. However, after implementing that, we observed that we missed a lot of good tiles around the edges of image areas (and around holes), that just a small part of them went beyond the area boundary. I.e - we missed a lot of data.

  • Our next interpretation was that a tile should only intersect with the area that it belongs to. But the opposite issue popped up - we had a lot of tiles which were pretty much outside any relevant area but were included because some small part of them intersected with the very edge of one. This is very problematic since some of these tiles had CR+ cells in them, which were not marked simply because they were outside of any relevant region. So if we would have built a learning model to locate CR+ cells and used this data to fit it, it'd have tried to avoid those unmarked CR+ cells (i.e. - we would impose a great deal of noise on the data).

  • At last we came up with a compromise: a tile should have a large enough intersection with the image area that it belongs to. How large? Well, that can be a tuneable hyperparameter. We played with it and decided that 50% of the area of the tile is a good compromise.

    All in all, the criterion we decided on for a tile $T$ to be considered as part of an image area $A$ is that at least 50% of the area of $T$ is contained in $A$.

We'll use the functionality of the standard shapely library to calculate the area of intersection between a tile and an image area. The way to do this is using the intersection function of the shapely.geometry.polygon.Polygon class. This function takes another shapely shape and returns a shape that represents their intersection. We can then use the area field of that shape to determine if it's large enough.

However, in practice, we noticed that the intersection function requires some pretty heavy calculations, and it takes a while to calculate the intersection of each tile with every relevant region. We used two methods to speed things up:

  1. We used the light-weight function intersects of shapely.geometry.polygon.Polygon to make sure we run the heavy intersection function only on pairs of shapes that actually intersect.

  2. We decided to make the whole process of tiling and attaching tiles to image areas as a preprocess, so it won't slow things down in the heavier parts of the project.

In [ ]:
'''
Converts a tile represented by top-left coordinates, width and height, to a shapely
  polygon,
'''
def tile2polygon(tile_x, tile_y, tile_w, tile_h):
    tile = Polygon(((tile_x, tile_y), (tile_x, tile_y+tile_h),
                    (tile_x+tile_w, tile_y+tile_h), (tile_x+tile_w, tile_y)))
    return tile

'''
Locates the image area of a whole-slide-image associated with a tile.
Input:
  > 'img_areas' - the list of image area dictionaries returned from 'process_wsi_annots'
    for the slide image.
  > 'tile' - a shapely Polygon representing the rectangle of the tile.
  > 'intersection_part' - for a tile to be considered relevant and associated with
    an image area, at least 'intersection_part' of its area should be contained
    in it (and they should intersect). Default: 0.5.

Return the index of the matching image area in the img_areas list, or -1 if the
  rectangle does not belong to any relevant area of the image.
'''
def locate_tile(tile, img_areas, intersection_part=0.5):

    # calculate minimum intersection area required
    min_intersect_area = intersection_part * tile.area

    # check every image area
    for i, img_area in enumerate(img_areas):
        if tile.intersects(img_area['boundary']) and \
          tile.intersection(img_area['boundary']).area >= min_intersect_area:
            return i

    # if we reached here - the rectangle does no intersect with any area
    return -1

After we located the relevant tiles, our models will treat them as separate samples of data. The most important part of the data is the cellmarks - but how will they come into play after tiling? There are thousands of markers placed all over the image - most of them are obviously not relevant for some small patch.

The markers that are relevant for a tile are the ones that are presented in it - which are the ones that intersect with it. So we will just check which markers intersect with the tile, and attach a list of them to it. We only need to notice that since tiles will be considered as separate data samples, the coordinates of the markers should be shifted relative to the tile's location on the image.

In [ ]:
'''
Extract the cell-marks that appear on a certain tile on a whole-slide image.
Input:
  > 'img_area' - the image area dictionary representing the image area associated
    with the tile (as returned from 'process_wsi_annots').
  > 'tile' - a shapely Polygon representing the rectangle of the tile.
  > 'tile_x' & 'tile_y' - the coordinates of the top left corner of the tile (re-
    calculated if they're None). Default: None.

Returns a list of pairs of cell-mark edge vertices that appear in the tile, each
  represented by a tuple ((x1,y1), (x2,y2)) of vertix coordinates.
Note - some of the cell-mark edge vertices coordinates may be negative (if part
  of the cell-mark is off the tile).
'''
def get_tile_cellmarks(img_area, tile, tile_x=None, tile_y=None):
    # recalculate top-left corner if necessary
    if tile_x is None:
        tile_x = min([vertix[0] for vertix in tile.exterior.coords])
    if tile_y is None:
        tile_y = min([vertix[1] for vertix in tile.exterior.coords])

    # initialize list of cell-marks to return
    tile_cell_marks = []

    # check every cell-mark associated with that image area
    for cell_mark in img_area['cellmarks']:
        if tile.intersects(cell_mark):
            # shift marker location and add it
            x1 = cell_mark.coords[0][0] - tile_x
            y1 = cell_mark.coords[0][1] - tile_y
            x2 = cell_mark.coords[1][0] - tile_x
            y2 = cell_mark.coords[1][1] - tile_y
            tile_cell_marks.append(((x1, y1), (x2, y2)))

    return tile_cell_marks

Let's make sure everything works well by using some arbitrary tile:

In [ ]:
# create some arbitrary tile
x = 30800; y = 17500; w = 500; h = 500
tile = tile2polygon(x, y, w, h)
tile_img = img_handle[y : y+h, x : x+w]

# process it
tile_area = locate_tile(tile, img_areas)
tile_cellmarks = get_tile_cellmarks(img_areas[tile_area], tile)

ax = plt.subplot()
ax.set_title('Manually Extracted Tile')
DSProjectUtils.plot_tile(ax, tile_img, tile_cellmarks)
None

That looks just right! Now we can work on an end-to-end function to extract tiles from an image:

In [ ]:
'''
Processes a whole-slide image, and extracts tiles from it. Only tiles withing relevant
  areas of the image are considered.

Input:
  > 'img_areas' - the list of image area dictionaries returned from 'process_wsi_annots'
    for the source slide-image.
  > 'img_shape' - shape of the source slide-image.
  > 'tile_shape' - shape of the tiles to extract.
  > 'intersection_part' - for a tile to be considered relevant and associated with an
    image area, at least 'intersection_part' of its area should be contained in it (and
    they should intersect). Default: 0.5.

Returns a list of tiles, each represented by a dictionary with 2 entries:
  > 'top_left_corner' - a pair of coordinates for the top-left corner of the tile.
  > 'cellmarks' - a list of cellmarks, each represented by a tuple ((x1,y1),(x2,y2))
    of vertix coordinates.
'''
def extract_tiles(img_areas, img_shape, tile_shape, intersection_part=0.5):
    # extract some basic info for ease
    img_h, img_w = img_shape
    tile_h, tile_w = tile_shape

    # initialize output lists
    tiles = []

    # check each possible tile corner
    for tile_x in range(0, img_w-tile_w, tile_w):
        for tile_y in range(0, img_h-tile_h, tile_h):
            tile = tile2polygon(tile_x, tile_y, tile_w, tile_h)
            tile_area = locate_tile(tile, img_areas, intersection_part)

            # add it if it's relevant
            if tile_area != -1:
                tile_cellmarks = get_tile_cellmarks(img_areas[tile_area],
                                                    tile, tile_x, tile_y)
                tiles.append({'top_left_corner' : (tile_x, tile_y),
                              'cellmarks' : tile_cellmarks})

    return tiles

Let's check it with our sample:

In [ ]:
# extract tiles of some arbitrary size
start_time = time.time()
h = 512; w = 512
tiles = extract_tiles(img_areas, img_handle.shape[:2], (h,w))
# calculate and inform the user about the time elapsed
time_elapsed = time.time() - start_time
print(f'{len(tiles)} tiles extracted in {time_elapsed:.2f}s')
2846 tiles extracted in 4.16s
In [ ]:
# display some of the extracted tiles
fig, axes = plt.subplots(1,3, figsize=(18,6))
tile_idxs = (501,1758,958)

for tile_idx, axis in zip(tile_idxs, axes):

    # extract tile info
    tile = tiles[tile_idx]
    x, y = tile['top_left_corner']
    tile_img = img_handle[y : y+h, x : x+w]
    tile_cellmarks = tile['cellmarks']

    # plot it
    DSProjectUtils.plot_tile(axis, tile_img, tile_cellmarks)
    axis.set_title(f'Tile {tile_idx}')

fig.suptitle('Automatically Extracted Tiles', fontsize=15)
None

Experimental Data Analysis¶

Done is an analysis of the xml data: the properties of the cells.

Mainly, the diameter-length feature is examined, as it is a target of the system. Other properties are less important - the most important are the lengths and the images.

First create a dataframe containing cells info.

In [ ]:
# load xml file and fetch lines (=cells) records
XML_PATH='samples/ID12718.xml'
file = minidom.parse(XML_PATH)
regions = file.getElementsByTagName('Region')
lines = [region
         for region in regions
         if region.attributes['Type'].value is '4']

Create procedures for fetching line attributes. First, the target diameter points' coordinates.

In [ ]:
# returns [X0,Y0,X1,Y1] diameter points of line number i
def get_labels(i):
    line = lines[i]
    # access relevant attributes through xml tree
    vertices = line.childNodes[3].childNodes

    X0 = int(float(vertices[1].attributes['X'].value))
    Y0 = int(float(vertices[1].attributes['Y'].value))
    X1 = int(float(vertices[3].attributes['X'].value))
    Y1 = int(float(vertices[3].attributes['Y'].value))

    return [X0, Y0, X1, Y1]

Then the rest of the data of a cell.

In [ ]:
# returns [attr_names],[attributes] of line i, where attributes is all line data except its diameters points
def get_aux_data(i):
    line = lines[i]
    pairs_name_val=line.attributes.items()
    attr_names = [attr[0] for attr in pairs_name_val]
    attr_val = [attr[1] for attr in pairs_name_val]
    return attr_names,attr_val

Now combine the two to get the entire data of a cell.

In [ ]:
# returns whole data of a line as [attributes],[values]
def get_line(i):
    labels=get_labels(i)
    attr_names,attr_vals=get_aux_data(i)
    values=attr_vals+labels
    names=attr_names+['X0','Y0','X1','Y1']
    return names,values
In [ ]:
# see output
print(get_labels(5),get_aux_data(5), get_line(5), sep='\n')
[10084, 36034, 10102, 36038]
(['Id', 'Type', 'Zoom', 'Selected', 'ImageLocation', 'ImageFocus', 'Length', 'Area', 'LengthMicrons', 'AreaMicrons', 'Text', 'NegativeROA', 'InputRegionId', 'Analyze', 'DisplayId'], ['268270', '4', '1', '0', '', '0', '18.4', '0.0', '9.3', '0.0', '', '0', '0', '0', '7'])
(['Id', 'Type', 'Zoom', 'Selected', 'ImageLocation', 'ImageFocus', 'Length', 'Area', 'LengthMicrons', 'AreaMicrons', 'Text', 'NegativeROA', 'InputRegionId', 'Analyze', 'DisplayId', 'X0', 'Y0', 'X1', 'Y1'], ['268270', '4', '1', '0', '', '0', '18.4', '0.0', '9.3', '0.0', '', '0', '0', '0', '7', 10084, 36034, 10102, 36038])

Now use the procedures to create a pandas dataframe holding cell information.

In [ ]:
all_data=[]
ln_names=None
for i in range(len(lines)):
    ln_names,ln_vals=get_line(i)
    all_data+=[ln_vals]
frame=pd.DataFrame(all_data,columns=ln_names)
In [ ]:
# see frame
frame.head()
Out[ ]:
Id Type Zoom Selected ImageLocation ImageFocus Length Area LengthMicrons AreaMicrons Text NegativeROA InputRegionId Analyze DisplayId X0 Y0 X1 Y1
0 267376 4 1 0 0 23.4 0.0 11.8 0.0 0 0 0 2 10444 25229 10429 25247
1 267377 4 1 0 0 9.2 0.0 4.6 0.0 0 0 0 3 10403 25628 10401 25637
2 267378 4 1 0 0 20.2 0.0 10.2 0.0 0 0 0 4 10614 26060 10603 26077
3 267379 4 1 0 0 20.4 0.0 10.2 0.0 0 0 0 5 10336 26143 10316 26147
4 267380 4 1 0 0 38.5 0.0 19.3 0.0 0 0 0 6 11913 23764 11878 23780

Look at the attributes data types, as these are important for analysis procedures.

In [ ]:
print(frame.dtypes)
Id               object
Type             object
Zoom             object
Selected         object
ImageLocation    object
ImageFocus       object
Length           object
Area             object
LengthMicrons    object
AreaMicrons      object
Text             object
NegativeROA      object
InputRegionId    object
Analyze          object
DisplayId        object
X0                int64
Y0                int64
X1                int64
Y1                int64
dtype: object

Many of the attributes are of 'object' type. Change relevant columns (ID, length and length-microns) to int or float, for later analysis.

In [ ]:
# change object to float

for col in ['Id','Length','LengthMicrons']:
    frame[col]=pd.to_numeric(frame[col])

# inspect the change
print(frame.dtypes)
Id                 int64
Type              object
Zoom              object
Selected          object
ImageLocation     object
ImageFocus        object
Length           float64
Area              object
LengthMicrons    float64
AreaMicrons       object
Text              object
NegativeROA       object
InputRegionId     object
Analyze           object
DisplayId         object
X0                 int64
Y0                 int64
X1                 int64
Y1                 int64
dtype: object

The left object columns are irrelevant for our task. We now examine basic statistics of the relevant attributes.

In [ ]:
frame.describe() # describes only non-object (=relevant) columns
Out[ ]:
Id Length LengthMicrons X0 Y0 X1 Y1
count 3.182000e+03 3182.000000 3182.000000 3182.000000 3182.000000 3182.000000 3182.000000
mean 6.061666e+05 25.343746 12.718730 19633.128221 21463.914205 19631.634507 21479.765556
std 3.838648e+05 14.603199 7.330952 11841.655905 9362.963460 11840.656573 9362.826031
min 2.673760e+05 8.000000 4.000000 1306.000000 507.000000 1298.000000 524.000000
25% 2.692052e+05 16.000000 8.000000 10345.750000 13248.000000 10340.250000 13262.750000
50% 2.788815e+05 20.200000 10.200000 16864.500000 23317.500000 16874.500000 23336.500000
75% 1.007349e+06 28.600000 14.400000 28982.250000 29873.250000 29027.500000 29881.750000
max 1.209295e+06 99.000000 49.700000 46398.000000 36578.000000 46401.000000 36592.000000

Note the relationship between Length and LengthMicrons: LengthMicrons $\approx$ Length/2. For example see their corresponding means.

Note the exact constant isn't 1/2 but 0.5019, as previously mentioned.

Furthermore, note that the coordinates Xi and Yi are relative to the whole, $\sim$4Gb image, and not to specific patches.

Note: the whole process of processing the XML into its current form is encapsulated in a function inside the utils python file of the project.

Plot the distributions of the targets: lengths and length microns.

In [ ]:
fig, axs = plt.subplots(ncols=2)
sns.histplot(frame['LengthMicrons'], ax=axs[0])
sns.histplot(frame['Length'], ax=axs[1])
plt.show()

We see that most cells have diameters lengths below 50. Note the exceptionals above 75.

Induced by the relationship between the two attributes, the LengthMicrons distribution is a squeezed version of the Length distribution.

Also note the peak at around 10ums, and the decay on its right. This behavior is abundant in WSI images.

Throughout the project we exclusively use the LengthMicrons as a target, because this is the metric used by the target user.


Processing the Cells¶

In this section, we'll understand the appearance of CR+ neurons in our samples, and use various image processing techniques to process them. The target goal of this section is to automatically extract segmentation maps of the neurons (assuming that we know their center point - which we'll work on in the next section), and use those to identify the cells' diameter points (i.e. ends of the diameter line).

Our process will be composed of several steps:

  • We'll first understand the appearance of the CR+ cells, and extract binary segmentation maps for what appears to look like ones.

  • Then, given the center point of a known cell, we'll extract a dedicated segmentation map for it. We will then use this map to calculate the cell's contour, in search of the two most distant points on it - the ends of the diameter. We can use these to automatically mark the cell.

  • Finally, we'll evaluate our method by comparing the automatically-produced cellmarks to those that were placed by Doctor Kelmer's research team.

The Appearance of CR+ Neurons¶

To come up with a technique to extract the segmentation maps, we first note several characteristics of the cells.

Cells have three general characteristics:

  1. They are dark.
  2. They may have little purple blobs adjacent to them. The blobs aren't considered a part of the cell, as dictated by Dr. Kelmer. [See left fig.]
  3. They may have several thin dark 'strips' stretching out of them (axons), as if they were balls with ropes sticked to them by their ends. These are also not considered a part of a cell. [See right fig.]
Figure 12: To the left - a cell with purple blobs; to the right - a cell with strips.

Because of the dark narture of the cells, simple thresholding yields an approximate segmentation of them. Then a contour detection may be carried out, from which we detect the diameter, as the two farthest apart points on the contour.

This technique doesn't taken into account the other two characteristics. This is crucial, as the purple regions may actually be larger than the cell itself, and the strips stretch the cell, both of which dramatically increase the diameter of the detected contour.

Therefore, when creating a segmentation of the cell, purple blobs and dark strips should be excluded. The purples will be removed by color filtering prior to image gray-scaling, and the strips through morphological opening applied on the detected segmentation.

Given these considerations, the following procedure has been determined: given an RGB image containing a cell at its center,

  1. Remove specific kinds purple pixels from the image. Done because there often are dark-enough purples adjacent to the cell that pass thresholding. Because of their proximity, the purple regions will be contained inside of the connected-component of the cell, despite them not being a part of the cell. That's why purples are removed before thresholding.
  2. Convert the image to grayscale. So thresholding may be applied.
  3. Apply a negative-thresholding on the image, yielding a binary image that preserves dark intensities. The cell is dark, and such action would only keep cell pixels active. The result is a H by W binary image with 0s and 1s representing an initial cell segmentation.
  4. Make a small neighborhood of the cell-center have values of 1. It is quite common that the cell-center itself isn't of high enough intensity to pass the threshold. But its neighbors probably will pass the threshold, and thus we set the cell-center's value to 1. Because we calculate connected components this may not be enough. When the neighborhood of the cell doesn't pass thresholding the connected component containing the center is only itself. Thus we extend the 1-setting to a small neighborhood of the center, so it may fuse between the cell's connected component and the rest of the cell's component.
  5. Fill holes in the image (using dilations). To make the connected components 'whole'. Holes inside connected components create an inner-contour of the connected component, other than the wanted outer-contour of the cell (imagine a circle with an inner-circle-hole, the contour of this shape is the union of the two circumferences).
  6. Apply opening on the image, with a small enough structuring element (s.t. the cell-center survives the action). In order to extinguish possible 'strips' coming out of the cell, we apply the opening morphological operation. Thin anough strips would get removed. Opening may also remove the cell-center, so we choose a structuring element (the hyperparameter of opening) that keeps it active. No matter what we do, the cell-center should stay active.
  7. Get the connected component containing the cell-center, which is the desired image-centered cell segmentation.
  8. Get the contour of the connected component (=wanted cell) through complement-of-erosion.
  9. Find the two farthest apart points on the contour and return them as the diameter ends.

Notes:

  • Images given to the method are not normalized, and won't get normalized in any step of the process.
  • cv2 loads images in BGR format by default, and not RGB. Legacy 'feature'.

Creating the Dataset¶

Create a dataset solely for the use of this part. It will return only patches containing a cell.

In [ ]:
'''
This class represents a dataset of patches of whole-slide-images, centered on the
  different cellmarks it contains.
Each element in the dataset is a pair (patch,cellmark) where:
  > 'patch' is a numpy array of shape (3,h,w) containing a patch of the WSI centered
    on some cellmark.
  > 'cellmark' is a tuple ((x1,y1),(x2,y2)) of the cellmark vertices coordinates.
'''
class IHCCellsDataset(Dataset):
    '''
    Constructs a new cells dataset.
    Input:
      > 'img_handle' - opened pytiff handle for the tiff containing the WSI.
      > 'img_areas' - list of image area dictionaries (as returned from
        'process_wsi_annots').
      > 'patch_shape' - the shape of the image patches to produce.
    '''
    def __init__(self, img_handle, img_areas, patch_shape):
        # store patch dimensions and tiff handle
        self.patch_height, self.patch_width = patch_shape
        self.img_handle = img_handle

        # collect the coordinates of all cellmarks
        self.cellmarks = sum([img_area['cellmarks'] for img_area in img_areas], [])

    '''
    Returns the number of elements in this dataset.
    '''
    def __len__(self):
        # the number of elements in the dataset is the number of cellmarks in the image
        return len(self.cellmarks)

    '''
    Gets an element from this dataset.
    Input - index of the element.
    The format of the returned element is described in the dataset's documentation
      above.
    '''
    def __getitem__(self,  i):
        # collect the edge-vertices of the cell-mark, and calculate their average
        vertex1, vertex2 = self.cellmarks[i].coords
        mid_point = ((vertex1[0]+vertex2[0]) / 2, (vertex1[1]+vertex2[1]) / 2)

        # calculate top-left corner of the patch
        top_left_x = int(mid_point[0] - self.patch_width/2)
        top_left_y = int(mid_point[1] - self.patch_height/2)

        # calculate relative coordinates of edge-vertices of the cell mark
        rel_vertex1 = (vertex1[0] - top_left_x, vertex1[1] - top_left_y)
        rel_vertex2 = (vertex2[0] - top_left_x, vertex2[1] - top_left_y)

        # extract patch
        patch = self.img_handle[top_left_y : top_left_y + self.patch_height,
                                top_left_x : top_left_x + self.patch_width]

        return patch, (rel_vertex1, rel_vertex2)

The used patch size is 110, and thus the center of a patch has coordinates (55,55) relative to the patch.

Create a dataset object:

In [ ]:
percell_dataset = IHCCellsDataset(img_handle, img_areas, (110,100))
print(f'Putamen sample at {tiff_path} contains {len(percell_dataset)} CR+ neurons.')
Putamen sample at samples/ID12718.tiff contains 3182 CR+ neurons.

We see there are 3182 marked cells in the image.

We now implement each step in the process described above.

Steps Implementation¶

Removing Purples¶

The filtering of purples is done through the filtering of pixels according to their Hue.

The Hue is a measure of color that encapsulates properties relevant to our use.

We first create a procedure that finds the purples in an image.

In [ ]:
'''
Input: img - RGB image in numpy of shape HxWx3, of dtype uint8
Output: binary (True [1] /False [0]) ndarray of shape HxW.
        An item is True if the respective pixel in img is purple.
'''
def get_purples(img):
    # Open image as PIL Image and make RGB and HSV versions
    RGBim=Image.fromarray(img).convert('RGB')
    HSVim = RGBim.convert('HSV')

    # Make numpy versions
    RGBna = np.array(RGBim)
    HSVna = np.array(HSVim)

    # Extract Hue
    H = HSVna[:,:,0]

    # Find all purple pixels, i.e. where 270 < Hue < 350
    lo,hi = 270,350
    # Rescale to 0-255, rather than 0-360 because we are using uint8
    lo = int((lo * 255) / 360)
    hi = int((hi * 255) / 360)
    # get hues inside the range
    purple = np.where((H>lo) & (H<hi))
    # create mask of purples
    mas=np.zeros((RGBna.shape[0],RGBna.shape[1]))
    mas[purple]=1
    return mas>0

We now use the procedure to remove purple pixels from an image by setting them to be white. That way, the purples will become white in the gray-scaled image, and thus won't pass thresholding. Hence, they will get discarded.

In [ ]:
'''
Input: img - RGB image in numpy of shape HxWx3
Output: img with its purple pixels replaced with white pixels.
        I.e., numpy RGB image of the same shape.
'''
def remove_purples(img):
    mas=get_purples(img) # get mask
    cp=img.copy() # don't change the original
    cp[mas]=[255,255,255] # replace purples with white
    return cp

Now see an example of removing the purple pixels from a colorful image.

In [ ]:
# example
circ=cv2.imread('circle.png') # loads in BGR
# convert to RGB, default is BGR
circ=cv2.cvtColor(circ, cv2.COLOR_BGR2RGB)


print('Original image:')
plt.axis('off')
plt.imshow(circ)
plt.show()

print('After purples removed:')
plt.axis('off')
plt.imshow(remove_purples(circ))
plt.show()
Original image:
After purples removed:

As can be seen, the purples in the image were removed and replaced with white pixels.

The excluded hue range includes a slightly wider range of colors than purples only, as a safety measure.

After removing purples, the image is converted to grayscale and a thresholding is carried out, which outputs a binary image.

Fetching Connected Components¶

Given a binary image, recognition of the connected components inside it is done, so the cell's component could be fetched.

In [ ]:
'''
Input: threshed - binary numpy array containing 0-1 values.
Output: image the same size of threshed, where each pixel is the index of its respective
        connected component.
        Class index of the background is 0.
'''
def get_connected_comps(threshed):
    blobs_labels = measure.label(threshed, background=0)
    return blobs_labels

And now we plot an example. Each connected component is colored differently.

In [ ]:
'''
Input: threshed - binary numpy array containing 0-1 values.
Operation: plots the connected components recognized inside threshed.
'''
def plot_conn(threshed):
    all_labels = measure.label(threshed)
    plt.figure(figsize=(9, 3.5))
    plt.subplot(121)
    plt.imshow(threshed, cmap='gray')
    plt.axis('off')
    plt.subplot(122)
    plt.imshow(all_labels, cmap='nipy_spectral')
    plt.axis('off')
    plt.show()
In [ ]:
# example
# load image
img,lbl=percell_dataset[5]
# convert to grayscale and threshold
img=cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img=(img<110)
print('Thresholded image & Detected Connected Components:')
plot_conn(img)
Thresholded image & Detected Connected Components:

As we see, the connected components were successfully detected.

Opening¶

After receiving the connected component of the cell, potential strips are removed by the use of an opening operation.

The opening operation is the dilation of the erosion of the component.

  1. The erosion looks at each pixel's neighbourhood (determined by a specific radius) and checks if all of it is contained inside the component. If so, it remains active and otherwise shut down. After this step, the strips are gone as they are thin and hence their pixels' neighbourhoods are not contained inside the component, the radius is too big. Moreover, after this step, the true connected component's contour and some pixels adjacent to them are chopped off.
  2. The dilation operation restores them. It looks at each pixel and checks if its neighbourhood (of the same radius as the one used in the erosion) intersects what's left of the component. If so it becomes active and otherwise remains the same. Note that dilation doesn't restore the pixels of the strips. That's because they are gone and the dilation only adds a strap around the eroded component.

Again note: strip pixels are thin, so a neighbourhood is larger than its width. Hence these pixels tend not to pass opening, and so we eliminate strips while preserving the contour of the cell's true connected component.

Further note that the erosion is an interior detector. Hence, the complement of an eroded component w.r.t the component is the contour of it.

In [ ]:
'''
Input: image - binary numpy image
       size - size of the structuring element
Output: opening of the image by a circle with diameter of size 'size' as a structuring element.
        numpy binary array of the shape of image.
'''
def do_opening(image,size):
    half=int(size/2.) # make circle centered
    circle_image = np.zeros((size, size))
    circle_image[disk((half,half), half)] = 1
    image=opening(image, circle_image)
    return image

Auxiliary¶

Some auxiliary functions are defined.

The first is used to find the two farthest points on the contour.

In [ ]:
'''
Input: xs - numpy array of x-coordinates of points, shape n
       ys - numpy array of y-coordinates of points, shape n
Output: the two points inside {(x1,y1),...,(xn,yn)} which are farthest apart.
        breaks ties arbitrarily.
        Return format: (p0,p1) where p0 & p1 are numpy arrays of size 2,
                        containing [x,y] coordinates.
'''

def get_max_dist(xs,ys):
    largest_dist=-1
    best_0=best_1=None
    for i in range(len(xs)):
        for j in range(len(xs)):
            p0=np.array([xs[i],ys[i]])
            p1=np.array([xs[j],ys[j]])
            d=(((p0-p1)**2).sum())**(1/2.)
            if d>largest_dist:
                largest_dist=d
                best_0=p0
                best_1=p1
    return best_0,best_1

The other to get the relevant connected component ID for the cell-center.

In [ ]:
'''
Inputs: image - integer image
        y,x - respective coordinates of a point inside of image.
Output: the nearest non-zero pixel value to the pixel at (x,y).
'''
def get_nearest_nonzero(image, y,x):
    dup=image.copy()
    dup[y,x]=0 # so we don't achieve output=(x,y)
    ys,xs=np.where(dup!=0)
    dists_sq=(xs-x)**2+(ys-y)**2 # distance-squared from all non-zero pixels
    idx=np.argmin(dists_sq) # minimum distance-squared is also minimum distance
    return dup[ys[idx],xs[idx]]

Final Procedure¶

Recall the procedure:

  1. Remove specific kinds purple pixels from the image. They're replaced with white pixels, so they don't pass thresholding.
  2. Convert the image to grayscale.
  3. Apply a negative-thresholding (i.e. apply Indicator[x\< threshold]) on the image, yielding a binary image that preserves dark intensities.
  4. Make a small neighborhood of the cell-center have values of 1.
  5. Fill holes in the image (using dilations).
  6. Apply opening on the image, with a small enough structuring element (s.t. the cell-center survives the action).
  7. Get the connected component containing the cell-center, which is the desired image-centered cell.
  8. Get the contour of the cell (i.e. the connected component) through the complement-of-erosion.
  9. Find the two farthest apart points on the contour and return them as the diameter endpoints.

The implementation follows.

First is a procedure that fetches the segmentation of a cell. We apply openings with neighbourhood sizes of 22, 18 and 12 pixels, taking one that preserves the center with preference for the largest radius opening.

In [ ]:
'''
Input: image - RGB numpy image containing a cell, of dtype uint8
        center - array [x,y] of the coordinates of the cell center inside image.
                  Defaults to [55,55].
                  If the size of the image isn't 110x110, center should be updated appropriately.
        open - boolean, controls whether to apply opening.
              When open=True and no fit (i.e. that doesn't extinguish the center) structuring element is found, opening isn't carried out.
Outputs: segmentation of the cell inside image, whose center is the one given.
         Format: np array of the size of image's width and height, containing 1s and 0s.
'''
def get_cell_segmentation(image, center=[55,55], open=True):
    # remove purples
    image=remove_purples(image)
    # convert to grayscale (we have RGB)
    image=cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # threshold image with empirically-proven threshold of 110
    threshed=image<110
    # fill neighberhood of center s.t. if he is in a hole, then he will merge
    # into the cell's connected component
    threshed[center[0]-5:center[0]+5,center[1]-5:center[1]+5]=True
    filled_holes=scipy.ndimage.binary_fill_holes(threshed) # fill holes so connected components are whole
    # do opening, choose largest size that keeps center active
    if open:
        big=do_opening(filled_holes, size=22)
        med=do_opening(filled_holes,size=18)
        small=do_opening(filled_holes,size=12)
        # apply only if the center survives
        if big[center[0],center[1]]>0:
            filled_holes=big
        elif med[center[0],center[1]]>0:
            filled_holes=med
        elif small[center[0],center[1]]>0:
            filled_holes=small
        # otherwise don't apply

    # get nearest connected component to center
    components_image = get_connected_comps(filled_holes)
    cell_class=get_nearest_nonzero(components_image, center[0],center[1])
    # keep only the center's connected component
    only_conn=filled_holes.copy()
    only_conn[components_image!=cell_class]=0
    return only_conn

The second runs the entire procedure: fetches the segmentation, calculates the contour and finds the diameter.

In [ ]:
'''
Input: image - RGB numpy image containing a cell, of dtype uint8
       center - array [x,y] of the coordinates of the cell center inside image.
                  Defaults to [55,55].
                  If the size of the image isn't 110x110, center should be updated appropriately.
       open - boolean, controls whether to apply opening.
              When open=True and no fit (i.e. that doesn't extinguish the center) structuring element is found, opening isn't carried out.
       segmentator - the segmentation function used. Changed as we next improve the method.
Outputs: prediction for the diameter-ends of the cell.
         Format: (p0,p1), where p0 & p1 are numpy arrays containing [x,y] coordinates of the ends.
'''
def get_diameter_points(image, center=[55,55], open=True, segmentator=get_cell_segmentation):
    only_conn=segmentator(image,center,open)
    # get contour of the component
    no_edges = erosion(only_conn)
    edges=only_conn^no_edges # complement
    ys,xs  = np.where(edges!=0) # get coordinates pixels that survived
    # return most-distant pair from contour
    return get_max_dist(xs,ys)

We now test the procedure on an example cell.

In [ ]:
# example
tst_img,_=percell_dataset[29]
p0,p1=get_diameter_points(tst_img)
print('Predicted diameter points:',(p0,p1))
plt.imshow(tst_img)
plt.scatter([p0[0],p0[1]],[p0[1],p1[1]])
plt.axis('off')
plt.show()
Predicted diameter points: (array([58, 44]), array([40, 68]))

The two blue points are the predicted diameter. Looks nice, it detects the strip and ignores it.

Next is a procedure for diagnosing each step of the process, plotting the output of each step.

In [ ]:
'''
A copy of the get_diameter_points function, that also plots every stage of computation.
'''
def diagnose(image, center=[55,55], open=True):

  fig = plt.figure(figsize=(13, 7))
  fig.suptitle('Diagnose Results', fontsize=15)
  plt.axis('off')
  NIMGS=6
  cols=4.
  rows=int(NIMGS/cols)+1


  ax=fig.add_subplot(rows,cols, 1)
  plt.axis('off')
  plt.imshow(image)
  ax.title.set_text('Original RGB Image')
  # remove purples
  image=remove_purples(image)

  ax=fig.add_subplot(rows,cols, 2)
  plt.axis('off')
  plt.imshow(image)
  ax.title.set_text('After Purples Removed')

  # convert to grayscale
  # first we have BGR. convert to gray
  image=cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

  ax=fig.add_subplot(rows,cols, 3)
  plt.axis('off')
  plt.imshow(image, cmap='gray')
  ax.title.set_text('Grayscaling Result')

  threshed=image<110 # threshold image


  ax=fig.add_subplot(rows,cols, 4)
  plt.axis('off')
  plt.imshow(threshed, cmap='gray')
  plt.scatter([55],[55])
  ax.title.set_text('Thresholding Result')

  # fill neighberhood of center s.t. if he is in a hole, then he will merge
  # into the cell's connected component (switch 5 with an appropriate width of a cell/2)
  threshed[center[0]-5:center[0]+5,center[1]-5:center[1]+5]=True
  filled_holes=scipy.ndimage.binary_fill_holes(threshed) # fill holes so connected components are whole (good for when the center doesnt pass thresholding
  # but its surroundings are
  if open:
      big=do_opening(filled_holes, size=20)
      med=do_opening(filled_holes,size=18)
      small=do_opening(filled_holes,size=12)
      # apply only if  the center survives
      if big[center[0],center[1]]>0:
          filled_holes=big
      elif med[center[0],center[1]]>0:
          filled_holes=med
      elif small[center[0],center[1]]>0:
          filled_holes=small

  ax=fig.add_subplot(rows,cols, 5)
  plt.axis('off')
  plt.imshow(filled_holes, cmap='gray')
  ax.title.set_text('Opening Result')

  components_image = get_connected_comps(filled_holes)
  cell_class=get_nearest_nonzero(components_image, center[0],center[1])
  only_conn=filled_holes.copy()
  only_conn[components_image!=cell_class]=0
  no_edges = erosion(only_conn)
  edges=only_conn^no_edges

  ax=fig.add_subplot(rows,cols, 6)
  plt.axis('off')
  plt.imshow(edges, cmap='gray')
  ax.title.set_text('Detected Contour')

  ys,xs  = np.where(edges!=0)

  plt.show()
  return get_max_dist(xs,ys)

We test the procedure:

In [ ]:
# diagnose an image
diagnose(tst_img)
plt.show()

As seen above, the cell has a purple blob adjacent to it and a single strip stretching out of it.

The blob is removed by colour filtering, as shown in the second image: it became white.

The third image is after grayscaling, and the fourth is after thresholding.

Notice that the purple blob didn't pass thresholding. That's because it was set to be white.

The fifth image shows the result of the opening, which wipes out the strip and smooths the contour of the connected component.

The last image shows the calculated contour of the cell segmentation, from which the diameter points were found.

Evaluation of Our Marking Method¶

Now we evaluate our method, first on a few examples and then on the whole dataset.

In [ ]:
print(f'The percell-dataset contains {len(percell_dataset)} samples.')
The percell-dataset contains 3182 samples.

Single-Cell Evaluation¶

Start by analyzing a few results by hand to get a sense of the performance. We do that by both plotting the results as well as comparing the diameter lengths.

Note: green is always used for ground-truth visualization, whereas orange is used for predictions.

First is a procedure calculating predicted diameter length.

In [ ]:
'''
Input: image - numpy RGB image, pixels in 0-255, dtype uint8
Output: predicted diamter length from get_diameter_points.
'''
def get_diameter_length(image, segmentator=get_cell_segmentation):
  p0,p1=get_diameter_points(image,segmentator=segmentator)
  return (((p0-p1)**2).sum())**(1/2.)

Second is a function that plots a prediction alongside the ground truth, one sample at a time. It is later extended to several samples at a time.

The plot will contain the cell with a line representing the predicted diameter, and another line representing the ground truth (if given).

For later extension, we give the option to not plot on screen immediately, so it could be used inside a sub-plot.

In [ ]:
'''
Input:
  image - numpy RGB image of dtype uint8, pixels in 0-255
  line - ground truth [(x0,y0),(x1,y1)] (for evaluation on dataset). May be None.
  fig - figure to plot inside of. May be None, in which case a new figure is used.
  show - is True, does plt.show(). Otherwise doesn't.
  segmentator - the cell segmentator to use.
Operation: plots the image, the predicted diameter (orange), and ground truth diameter (green) inside fig if given.
Returns: the figure containing the image, prediction and ground truth.
'''
def plot_pred(image, line=None, open=True, show=True, fig=plt.figure(figsize=(10,7)), segmentator=get_cell_segmentation):
  # predict diameter
  p0,p1=get_diameter_points(image, open=open, segmentator=segmentator)
  plt.imshow(image)
  # plot prediction
  plt.axis('off')
  plt.plot([p0[0],p1[0]], [p0[1],p1[1]], linewidth=3, color='orange', alpha=.7)
  if line is not None:
    # plot ground truth
    plt.plot([line[0][0],line[1][0]], [line[0][1],line[1][1]], linewidth=3, color='green', alpha=.7)
  if show:
      plt.show()
  return fig
<Figure size 720x504 with 0 Axes>

We now test the procedure.

In [ ]:
# evaluate a sample
tst_img,tst_label=percell_dataset[565]
tst_img=np.array(tst_img)
plot_pred(tst_img, tst_label)
plt.show()

The lines fit nicely, with the prediction being a bit longer than the ground truth while staying inside the cell.

Evaluation on the Whole Dataset¶

We evaluate the method on the whole dataset, using the previously defined procedures.

We are given with a IHCCellsDataset datasets, which returns only cell images from the WSI. We then use it to fetch all of the cells, and run the percell method on each one.

In [ ]:
'''
Inputs:
  dataset - IHCCellsDataset class instance to evaluate on.
Returns:
  two np lists: diameter-length predictions, diameter-length ground truth (as in dataloader).
  both of size len(dataset).
'''
def predict_dataloader(dataset, segmentator=get_cell_segmentation):
  # initialize returned arrays
  length_preds=[]
  truth_lengths=[]
  for bnum, batch in enumerate(tqdm(dataset)):
    # fetch and convert from tensor to ndarray
    img, label = np.array(batch[0]), np.array(batch[1])

    gt_points=np.array([label[0][0],label[0][1],label[1][0],label[1][1]])
    # calculate diameter lengths from labels (recall gt_points=[x0,y0,x1,y1])
    gt_length=((gt_points[0]-gt_points[2])**2+(gt_points[1]-gt_points[3])**2)**(1/2.)

    # evaluate the image
    plen=get_diameter_length(img,segmentator=segmentator)

    length_preds+=[plen]
    truth_lengths+=[gt_length]
  # returns predictions and ground-truths as ndarrays
  return np.array(length_preds),np.array(truth_lengths)

We now evaluate the percell on the known WSI.

In [ ]:
# run evaluation on the whole dataset
percell_predictions, percell_ground_truths = predict_dataloader(percell_dataset)
100%|██████████| 3182/3182 [05:13<00:00, 10.14it/s]

Results Analysis¶

Analysis is done with absolute errors for easy comprehension.

Furthermore, analysis is done in ums and not in pixels, as the target user uses ums for a measure.

We first fetch the absolute errors for the lengths predicted.

In [ ]:
# calculate absolute errors in length
abs_errors=np.abs(percell_ground_truths-percell_predictions)
MICRONS_PER_PIX=0.5019
# convert to microns
abs_errors_microns=abs_errors*MICRONS_PER_PIX
# convert to dataframe (for later sorting)
abs_frame=pd.DataFrame(abs_errors_microns, columns=['errors'])

Glimps at the absolute errors:

In [ ]:
sns.histplot(abs_errors_microns)
plt.show()

Notice that the majority of errors are below 5 ums.

We now examine classical statistical attributes.

In [ ]:
print('Absolute Error statistics, in um:\n')
pd.set_option('float_format', '{:f}'.format)
print('Variance: ',float(abs_frame.var()))
abs_frame.describe(percentiles=[.25,.50,.75,.90,.95])
Absolute Error statistics, in um:

Variance:  9.914963328114457
Out[ ]:
errors
count 3182.000000
mean 2.616310
std 3.148803
min 0.000000
25% 0.844639
50% 1.718503
75% 3.039958
90% 5.613110
95% 8.449492
max 32.805072

We see that 90% of predictions have absolute error lower than 5 ums and 75% have errors lower than 3 ums.

Further note the mean and std of the absolute errors: 2.6 and 3 respectively. This low mean and std indicate nice performance.

The following takes the concept further, by plotting the Cumulative Distribution Function (CDF) of the errors.

In [ ]:
# returns the percentage of samples having error below a given threshold
def percentage_below(errors,thresh):
  return float((errors<thresh).sum()/len(percell_dataset))
In [ ]:
# print selected values of the CDF
for e in [5,7,10,15,20]:
  print(f'Percentage below {e} abs error:', percentage_below(abs_errors_microns,e))

# plot the CDF
sns.displot(abs_frame, x="errors", kind="ecdf")
plt.show()
Percentage below 5 abs error: 0.8815210559396606
Percentage below 7 abs error: 0.9283469516027656
Percentage below 10 abs error: 0.9638592080452546
Percentage below 15 abs error: 0.9855436832181018
Percentage below 20 abs error: 0.9962287869264613

About 92% of the samples produced errors lower than 7 ums. As can be guessed from the previous analysis, the function gets to 0.9 quickly, at around 5 ums.

We further examine the error distribution, by zooming into the below-7 error zone, which contains 92% of the errors.

In [ ]:
# the majority of errors are under 7
# clip to that range to better view the results
cleaned_abs_err=abs_frame[abs_errors_microns<7]
sns.displot(cleaned_abs_err, x='errors')
plt.show()

We see that most errors reside on the left side of the graph, with a small number of samples being at the 6-7 range.

Now extend per-cell visualizations to several-cells visualization. An option for saving the figure is given (through which we evaluated many samples, saving many multi-cell results and inspecting them).

In [ ]:
'''
Inputs:
  indexes - np array containing indexes of samples inside dataset (as in dataset.get_nonempty_element)
  dataset - dataset of type IHCCellsDataset to predict from
  title - string, figure title to set
  savepath - string/None, path to save the figure to. Should have image extension (e.g. '.png'). May be None.
  segmentator - the cell segmentator function to use.
Operation:
  plots the predictions for the samples of indexes, with their ground truths. 4 predictions per row.
  if title is None, show the plot to the user. Otherwise, saves the figure to the file specified by the path 'savepath',
  and doesn't show the plot.
'''
def inspect_preds(indexes, dataset, title='', savepath=None, segmentator=get_cell_segmentation):
  # pyplot configurations
  fig = plt.figure(figsize=(25, 25))
  fig.subplots_adjust(hspace=0.4, top=.85)
  fig.suptitle(title, fontsize=15)
  plt.axis('off')
  cols=4
  rows=int(len(indexes)/cols)+1
  # go through indexes of indexes
  for idxidx in range(len(indexes)):
    i=indexes[idxidx]
    image,label=dataset[i] # get item specified by the index
    image,label=np.array(image), np.array(label)
    ax=fig.add_subplot(rows,cols, idxidx+1)
    f=plot_pred(image,label, show=False, fig=fig, segmentator=segmentator)
    ax.title.set_text(f'Cell {i}')
    f.show()
  if savepath is None:
    plt.show()
  else:
    plt.savefig(savepath,bbox_inches='tight')
  # prevent showing if savepath is None
  plt.ion()
  plt.close(fig)

We now make several evaluations.

In [ ]:
# show first 20 predictions
inspect_preds(np.array([0,10,20,30,40,50,60,70,80]), percell_dataset)

We see that the method performes well on most samples, having better results than the labels themselves.

We now sort the samples by their error magnitude.

That is done in order to inspect specific error ranges, so we may asses the severity of error magnitudes (some errors may be produced by better results than the ground truth).

In [ ]:
# sort errors
abs_err_sorted=abs_frame.sort_values(by='errors', ascending=True)
abs_err_sorted
Out[ ]:
errors
2296 0.000000
3025 0.000000
3004 0.000000
618 0.000000
694 0.000000
... ...
136 26.631316
742 27.247354
1310 27.422082
2475 32.032621
2649 32.805072

3182 rows × 1 columns

We now visualize predictions with error values in the middle of all of the evaluations.

In [ ]:
# dataset has 3182 samples
# visualize the medium-errored samples
inspect_preds(list(abs_err_sorted[2200:2220].index), percell_dataset)

We see that the model gets sometimes out of the boundaries of the cell. Many results indicate better than ground truth results (for example, see the second column). The others have a little smaller diameter (e.g. most bottom-left cell).

Thus the mid-error values are good in quality.

We now inspect the worst predictions.

In [ ]:
# visaulize the 20 most-errored samples
inspect_preds(list(abs_err_sorted[-20:].index), percell_dataset)

Examples for faults in the algorithm for specific cells are shown above. There are two causes for this malfunctioning. These causes are demonstrated below:

The first cell (top-left) is so un-dark that the cell didn't pass thresholding correctly.

The cell below that is a cell where the strips were not removed. This is because of the slenderness of the cell: all of the openings removed its center point. Thus opening wasn't executed and the strips remained active, and taken into account for the contour calculation.

In order to view evaluations at a bigger scale, we add a method to save such figures to storage.

The following saves predictions for samples that have errors within certain range. That enables us to asses the severity of different error magnitudes, as the labels are themselves volatile in accuracy.

Each such range contains many samples, so it's divided into batches (save one batch from a range at a time).

In [ ]:
'''
Inputs:
  minerr,maxerr - floats, specifies the range of error for considered samples. They define a sample set containing samples with error within the range.
  The set is divided into batches.
  batchno - number of batch to plot.
  sizebatch - size of each batch, number of samples to plot.
  basepath - string that ends with '/', the path of the base directory to operate inside of. All savings are done there.
Operation: gets a batch of size batchsize from the error-set defined by (minerr,maxerr).
        the batch is the ith batch inside the set, for i=batchno.
        Plots predictions+ground truths of these samples and saves them in a (possibly new) folder inside of basepath: minerr-maxerr/.
'''
def save_prediction_err_range(minerr,maxerr, batchno, sizebatch=16, basepath='error_visualizations/'):
  # get samples within the error range
  indexes=np.where(np.logical_and(abs_errors>=minerr,abs_errors<maxerr))[0]
  # get batch
  indexes=indexes[batchno*sizebatch:(batchno+1)*sizebatch]
  # create directory if doesn't exist
  dirpath=basepath+f'{minerr}-{maxerr}'
  if not os.path.isdir(dirpath):
    os.makedirs(dirpath)
  # save fig to this path
  savepath=dirpath+f'/{batchno}.png'

  title=f'Predictions for Range [{minerr},{maxerr}), batch {batchno}\n Green - Ground Truth\n Orange - prediction'
  inspect_preds(indexes, percell_dataset, title=title, savepath=savepath)

The code that applies the saving is below.

In [ ]:
# do 20 batches of each
# intervals=[(0,5),(5,7),(7,15)]
# NBATCHES=5
# for (m,M) in intervals:
#   for i in tqdm(range(NBATCHES)):
#     save_prediction_err_range(m,M,i)

Improving the Method¶

We now further examine the defects of our method, in order improve it.

Let us recall the most errored samples, so we may see what could be revised.

In [ ]:
# visaulize the 10 most-errored samples
inspect_preds(list(abs_err_sorted[-10:].index), percell_dataset)

We recognize two types of trouble-making (or 'strange') cells:

  1. Bright cells - cells that aren't dark.
  2. Small cells with large strips.

Although this observation is made based on this small sample of most-errored cells, further examination of the worst results indicates the same conclusion.

We now examine each troublemaker on its own, in search for a way to cope with each.

Troublemaker 1: Bright Cells¶

The following cell, cell number 2166, is an example of such troublemaker. This is the first cell in the above figure, which we will examine below.

In [ ]:
maker_bright,_=percell_dataset[2166]
In [ ]:
plt.axis('off')
plt.imshow(maker_bright)
plt.show()

As can be seen, the cell itself is bright, including its center. Therefore we bet that the cell doesn't pass the thresholding, which requires certain absolute grayscale levels.

In [ ]:
diagnose(maker_bright)
plt.show()

As we assumed the cell didn't pass thesholding.

Yet, we can still recognize the cell by eye, as it is darker than its surroundings.

Therefore we search for a way to strech the grayscale levels so that the threshold is less absolue but relative to the cells sorroundings.

We come up with a solution of histogram matching. Given two grayscale images, a source and reference, it transforms the source to a new image whose gray levels distribute similarly to the reference image levels, while reserving the contents of the source image.

We now look for a reference image that contains a wanted property: its cell is dark and differentiable from its background. The intuition is that this reference image has a gray-level distribution that will cause the source image to also have a distinguishable dark cell.

Such a cell could be taken from the best-errored samples. We select one that its dark cell takes roughly half of its small-patch.

In [ ]:
ref_hist,_=percell_dataset[113]
plt.axis('off')
plt.imshow(ref_hist)
plt.show()

Save the image for later use:

In [ ]:
plt.imsave('images/ref_img_hist.jpg', ref_hist)

We now carry out the histogram matching and see the results. Recall that the matching is done over the gray-scale images, so we need first to convert the source and reference to grayscale (recall that before grayscaling we first remove purples).

In [ ]:
src=remove_purples(maker_bright)
ref=remove_purples(ref_hist)
# convert to grayscale
# first we have BGR. convert to gray
src=cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
ref=cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)

matched=skimage.exposure.match_histograms(src, ref)
plt.imshow(matched,cmap='gray')
plt.axis('off')
plt.show()

We see that the transformed image's cell is dark. To see better we plot the original and the result image, alongsize the reference image and their respective histograms.

In [ ]:
figure, axis = plt.subplots(2, 3,figsize=(15,10))

# show grayscale images
axis[0, 0].axis('off')
axis[0, 0].imshow(src,cmap='gray')
axis[0, 0].set_title("Source GrayScale Image")

axis[0, 1].axis('off')
axis[0, 1].imshow(matched,cmap='gray')
axis[0, 1].set_title('Matched Image')

axis[0, 2].axis('off')
axis[0, 2].imshow(ref,cmap='gray')
axis[0, 2].set_title('Reference Image')

# show respective gray-level histograms
axis[1, 0].hist(src.reshape(-1))
axis[1, 0].set_title('Source Histogram')

axis[1, 1].hist(matched.reshape(-1))
axis[1, 1].set_title('Matched Histogram')

axis[1, 2].hist(ref.reshape(-1))
axis[1, 2].set_title('Reference Histogram')

# Combine all the operations and display
plt.show()

We see that the transformed image's histogram is very similar to the refence's histogram.

The important part of the histogram is its left hand part. In the original image the left hand (below 50) is empty, whereas in the transformed part it contains many pixels. This is the result of a selection that emphasizes black pixels, induced by the size of the cell in the reference image.

We see that the transformed image contains little holes inside the cell. Most of these will get filled by our method, as can be see below, with the continuation of the percell method.

In [ ]:
threshed=matched<110
center=[55,55]
# fill neighberhood of center s.t. if he is in a hole, then he will merge
# into the cell's connected component
threshed[center[0]-5:center[0]+5,center[1]-5:center[1]+5]=True
filled_holes=scipy.ndimage.binary_fill_holes(threshed) # fill holes so connected components are whole
plt.imshow(filled_holes)
plt.axis('off')
plt.show()

The rest of the unfilled pixels of the cells are in this case not important as they don't affect the connected component. This is an empirical assumption that we make.

This it how we tackle the bright cells. The effect on bright cells is positive, and we now examine its effect on the cells we already tackle successfully, to see if it worsens the performance of this wide majority of samples.

We look at a representative example. Later, on the evaluation of the improved method, we give further justification for the performance by applying it on the whole dataset.

The type of cell representative for examination is one that is widely affected by the matching. These cells are ones with totally different histograms than the reference image.

Therefore these cells are either very large or very small. The affect on very small ones is more interesting, as the effect on large cells is the reduction of the number of dark pixels, which doesn't change the diameter significantly.

A cell answering these conditions is cell number 10.

In [ ]:
test_cell,_=percell_dataset[10]
plt.axis('off')
plt.imshow(test_cell)
plt.show()

We apply the same matching process and check the results.

In [ ]:
src=test_cell
ref=remove_purples(ref_hist)
  # convert to grayscale
  # first we have BGR. convert to gray
src=cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
ref=cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)

matched=skimage.exposure.match_histograms(src, ref)
figure, axis = plt.subplots(1, 3,figsize=(15,10))

# show grayscale images
axis[0].axis('off')
axis[0].imshow(src,cmap='gray')
axis[0].set_title("Source GrayScale Image")

axis[1].axis('off')
axis[1].imshow(matched,cmap='gray')
axis[1].set_title('Matched Image')

axis[2].axis('off')
axis[2].imshow(ref,cmap='gray')
axis[2].set_title('Reference Image')

# Combine all the operations and display
plt.show()

It can be seen that the cell in the matched image is still differentiable from the background. It is now connected to external parts, yet the connections are thin and will get eliminated by opening. Hence, the performance on this sample is kept.

We have seen that this improvement results in good performance on a bright cell, while preseving performance on a regular cell.

Troublemaker 2: Small Cells with Long Strips¶

We look at a representative of such cells, cell number 187.

In [ ]:
maker_small,_=percell_dataset[187]
plt.axis('off')
plt.imshow(maker_small)
plt.show()

We diagnose the method on it:

In [ ]:
diagnose(maker_small)
plt.show()

We see that around the blue dot on the thresholding result image, we put a white rectangle, that can be spotted in the opening result image.

Yet, opening wasn't applied. This can be induced by the fact that the strip crawling from the cell to the bottom of the image remained after the opening. All of the openings, with the varying sizes, have eliminated the center points and hence non has been applied.

The strip should be eliminated, as the prediction includes it as part of the contour, as can be seen below.

In [ ]:
plot_pred(maker_small)
plt.show()

This can be achieved by adding smaller openings. Such opening won't remove the center of smaller cells.

The smallest opening currently carried out is of radius 12px. We add multiple-sizes openings, down to a radius of 4px.

It also won't affect the processing of larger cells, as the largest center-preserving opening is the one applied.

Implementation¶

The improved part is the segmentator. Given the segmentation, countour and diameter calculation is the same as before.

The changes relative to the original method are marked with #. They are:

  1. We match the grayscaled image to our reference image before thresholding.
  2. We loop through more openings with varying sizes, the smallest being 4px radius opening.
In [ ]:
'''
Input: image - RGB numpy image containing a cell, of dtype uint8
        center - array [x,y] of the coordinates of the cell center inside image.
                  Defaults to [55,55].
                  If the size of the image isn't 110x110, center should be updated appropriately.
        open - boolean, controls whether to apply opening.
              When open=True and no fit (i.e. that doesn't extinguish the center) structuring element is found, opening isn't carried out.
        open_sizes - sized of openings to apply. These are applied in the order they are given,
                    with the first one that doesn't remove the center being the one applied.
        ref_hist_path - path of the RGB Cell image that will act as reference to histogram matching.
Outputs: segmentation of the cell inside image, whose center is the one given.
         Format: np array of the size of image's width and height, containing 1s and 0s.
'''
def new_get_cell_segmentation(image, center=[55,55], open=True, open_sizes=[18,12,8,4], ref_hist_path='images/ref_img_hist.jpg'):
    # get reference image for matching  #
    ref_hist = plt.imread(ref_hist_path) #
    # remove purples
    image=remove_purples(image)
    ref_hist=remove_purples(ref_hist) #

    # convert to grayscale (we have RGB)
    image=cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    ref_hist=cv2.cvtColor(ref_hist, cv2.COLOR_BGR2GRAY)

    # carry out matching #
    match = skimage.exposure.match_histograms(image, ref_hist) #

    # threshold image with empirically-proven threshold of 110
    # apply on matched image #
    threshed=match<110  #

    # fill neighberhood of center s.t. if he is in a hole, then he will merge
    # into the cell's connected component
    threshed[center[0]-5:center[0]+5,center[1]-5:center[1]+5]=True
    filled_holes=scipy.ndimage.binary_fill_holes(threshed) # fill holes so connected components are whole
    # do opening, choose largest size that keeps center active
    if open:
      opening_result=filled_holes
      # go over all sizes #
      for s in open_sizes:
        curr_res=do_opening(filled_holes, size=s)
        # apply only if  the center survives
        # and if so continue
        if curr_res[center[0],center[1]]>0:
            opening_result=curr_res
            break
      # otherwise don't apply

    # get nearest connected component to center
    components_image = get_connected_comps(opening_result)
    cell_class=get_nearest_nonzero(components_image, center[0],center[1])
    # keep only the center's connected component
    only_conn=opening_result.copy()
    only_conn[components_image!=cell_class]=0
    return only_conn

Evaluation on the Whole Dataset¶

We run the same evaluation code as in the original method's evaluation, to see the change. We change 'segmentator' to the new segmentator.

In [ ]:
# run evaluation on the whole dataset
new_percell_predictions, new_percell_ground_truths = predict_dataloader(
    percell_dataset,
    segmentator=new_get_cell_segmentation)
100%|██████████| 3182/3182 [05:00<00:00, 10.59it/s]

Results Analysis¶

We run the same code as before, while comparing the results to the previous method.

In [ ]:
# calculate absolute errors in length
new_abs_errors=np.abs(new_percell_ground_truths-new_percell_predictions)

# convert to microns
new_abs_errors_microns=new_abs_errors*MICRONS_PER_PIX
# convert to dataframe (for later sorting)
new_abs_frame=pd.DataFrame(new_abs_errors_microns, columns=['errors'])
In [ ]:
# plot errors of old and new methods
plt.title('Absolute Errors')
sns.histplot(abs_errors_microns, color='blue', label='Old Errors')
sns.histplot(new_abs_errors_microns, color='red', label='New Errors')
plt.legend()
plt.show()

We see that the errors distribute more uniformly below 5. In addition, we see that above 10, there are more old errors than new errors, which indicates the improvement over the strange cells.

We now compute and compare statistics.

In [ ]:
print('-'*10, 'OLD BELOW', '-'*10)

print('Absolute Error statistics, in um:\n')
pd.set_option('float_format', '{:f}'.format)
print('Variance: ',float(abs_frame.var()))
print(abs_frame.describe(percentiles=[.25,.50,.75,.90,.95]))

print('-'*10, 'NEW BELOW', '-'*10)

print('Absolute Error statistics, in um:\n')
pd.set_option('float_format', '{:f}'.format)
print('Variance: ',float(new_abs_frame.var()))
print(new_abs_frame.describe(percentiles=[.25,.50,.75,.90,.95]))
---------- OLD BELOW ----------
Absolute Error statistics, in um:

Variance:  9.914963328114457
           errors
count 3182.000000
mean     2.616310
std      3.148803
min      0.000000
25%      0.844639
50%      1.718503
75%      3.039958
90%      5.613110
95%      8.449492
max     32.805072
---------- NEW BELOW ----------
Absolute Error statistics, in um:

Variance:  5.8490211390456315
           errors
count 3182.000000
mean     2.707217
std      2.418475
min      0.000000
25%      1.201390
50%      2.340190
75%      3.506857
90%      4.906447
95%      6.312497
max     29.786570

The more uniform distribution is noted here through a bigger mean error in the new method.

This comes at the advantage of a smaller std (2.4 compared to 3.1).

We further add that the 95th error-percentile at the new method is lower: 6.3 compared to 8.4. That's good!

We also see it in the following figure, a plot of the two CDFs:

In [ ]:
# print selected values of the CDF
print('OLD|NEW')
for e in [5,7,10,15,20]:
  print(f'Percentage below {e} abs error:', percentage_below(abs_errors_microns,e),'|',percentage_below(new_abs_errors_microns,e))

combined={'Old CDF':abs_errors_microns, 'New CDF':new_abs_errors_microns}

fig, ax = plt.subplots()

sns.ecdfplot(combined)
plt.show()
OLD|NEW
Percentage below 5 abs error: 0.8815210559396606 | 0.9050911376492772
Percentage below 7 abs error: 0.9283469516027656 | 0.9604022627278441
Percentage below 10 abs error: 0.9638592080452546 | 0.9827152734129478
Percentage below 15 abs error: 0.9855436832181018 | 0.994343180389692
Percentage below 20 abs error: 0.9962287869264613 | 0.997171590194846

We see an advantage for the new method between 5 and 10 absolute errors, where the new cdf exceeds the old cdf. In addition, in each printed comparison the new method achieves better percentages.

Now we examine the predictions on the worst-errored samples of the old method, using the new method. Recall that these errors were computed into abs_err_sorted.

In [ ]:
# visaulize the 10 most-errored samples
inspect_preds(list(abs_err_sorted[-10:].index),
              percell_dataset,
              segmentator=new_get_cell_segmentation)

We see generaly better results. For example, see the last cell (2649), which is a bright cell. Also see cells 742 and 1310, whose predictions fit (these are small cells with large strips).

Lets see what are the new method's worst errored samples.

In [ ]:
# sort errors
new_abs_err_sorted=new_abs_frame.sort_values(by='errors', ascending=True)
In [ ]:
# visaulize the 10 most-errored samples
inspect_preds(list(new_abs_err_sorted[-10:].index), percell_dataset, segmentator=new_get_cell_segmentation)

Most of these samples have many dark pixels inside of them. We assume that the faulty results are induced by the matching procedure, which reduces the amount of black pixels in these images.

Further, as can be seen in cell 136 or 56 (at the bottom), 'bridges' outside of the cells are considered a part of the cell. That's also because of the matching procedure.

Possible solutions for future work may include playing with the small-patch sizes, so that the vast majority of the black pixels in them will be inside the cell of the patch. For example, in cell 136 a smaller patch would remove the bottom-right black area and result with a better matching outcome.

This improved method better copes with strange cells, and it isn't perfect.


Cell Localization¶

The problem at hand is an instance of object detection - the problem of detecting instances of certain types of objects in image data. In our case - we need to localize CR+ neurons in patches of slide-images of putamen samples, and mark them along their longest diameter.

Naturally, object detection is one of the most central and well-known tasks in Computer Vision. It's been around for decades, and dozens of algorithms were developed to tackle it. While doing the research for this project, we stumbled across many of them:

Traditional algorithms, such as -

  • Viola & Jones detectors (P Viola & M Jones, 2001)

  • HOG detectors (N Dalal & B Triggs, 2005)

And more modern algorithms, based on the popular approach of deep learning -

  • R-CNN (R Girshick et al., 2014) and its variants:

    • Fast R-CNN (R Girshick, 2015)

    • Faster R-CNN (S Ren et al. 2015)

  • YOLO architectures (J Redmon et al., 2016)

And more. However, none seemed like an ideal fit for a stand-alone system. So, we developed some original ideas - strongly inspired by the algorithms we studied.

Before we begin, let's tile the image into patches of different resolutions. We'll use those patches throughout the section:

In [ ]:
# create small patches used for training (of resolution 256px X 256px)
small_patch_shape = (256,256)
small_patches = extract_tiles(img_areas, img_handle.shape[:2], small_patch_shape)

# create large patches used for visualization (of resolution 512px X 2048px)
large_patch_shape = (512,2048)
large_patches = extract_tiles(img_areas, img_handle.shape[:2], large_patch_shape)

We'll also use the same set of transformations on the loaded patches, so let's define these in advance:

In [ ]:
# transformation for extracted patches
trans = transforms.ConvertImageDtype(torch.float)

First Attempt: Heatmap Regression¶

The Idea and the Dataset¶

While studying a very different task - human pose estimation, we encountered an approach that seemed like a good fit to our problem - heatmap regression (A Bulat & G Tzimiropoulos, 2016).

The idea is to fit a model that will get an image patch as an input, and will generate a one-channel heatmap of the same resolution containing values between 0.0 and 1.0 - that can then be used to easily locate the centers of the CR+ neurons that we look for in the patch.

We wish that the heatmap could tell us what parts of the patch are likely to be the center of a cell, so we'll denote high values in the (ideal) heatmap as if the corresponding pixel in the source patch is likely to be the center of the cell, and vice versa with low values.

So if we know where the cells are in our training data, we just need to decide how exactly we want our heatmaps to look like - so we can fit a model to regress them. To get some nice and continuous heatmaps, we decided to express the cells in them by drawing the PDF of a 2-dimensional gaussian distribution (with some constant standard deviation in all directions) around the center of each cell (identified by the center of its known cellmark). We'll combine those together into a single map by taking the maximum of the PDFs. Finally, we'll standardize the resulting maps so the maximum entry in them is 1.0.

Figure 13: Generating a heatmap from a patch containing cellmarks. The standard-deviation used here for the gaussians' PDF is 100.0 (a large SD for visualization).

Let's write a PyTorch dataset to create those heatmaps, alongside the original input image tiles they represent:

In [ ]:
'''
This class represents a dataset for detecting multiple brain cells in IHC stained
  whole-slide-images (WSI), based on a generated heatmap of localized cells.
'''
class HeatmapDataset(Dataset):
    '''
    Constructs a new dataset.
    Input:
      > 'img_handle' - an opened pytiff handle object for the TIFF slide-image.
      > 'patches' - list of patch dictionaries, as returned from 'extract_tiles'.
        It is assumed that all patches have the same shape.
      > 'patch_shape' - shape of the patches in 'patches'.
      > 'gaussian_sd' - standard deviation for the gaussian-like distribution heatmaps
        produced around the center of each cell. Default: 5.0.
      > 'trans' - a callable that will be called on the image patches before
        returning them. Default: None (no transformation will be applied).
    '''
    def __init__(self, img_handle, patches, patch_shape, gaussian_sd=5.0, trans=None):
        # store relevant info
        patch_height, patch_width = patch_shape
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.img_handle = img_handle
        self.patches = patches
        self.gaussian_sd = gaussian_sd
        self.trans = trans

        # grid of 2D coordinates (shape (patch_height,patch_width,2)) on which we
        # apply the PDF of the gaussians to create the heatmaps
        self.patch_coords_grid = np.dstack(np.mgrid[0:patch_height:1, 0:patch_width:1])

    '''
    Returns the number of elements in this dataset.
    '''
    def __len__(self):
        # the number of elements is the number of patches
        return len(self.patches)

    '''
    Gets an element from this dataset.
    Input - index of the element.
    Returns a pair (img, target), where:
    > 'img' is a PyTorch tensor of shape (3,h,w) containing a patch of the tiff image,
      which may or may not contain cells.
    > 'target' is a PyTorch tensor of shape (1,h,w) with entries between 0.0 and 1.0.
      Most of the entries in it are 0.0, except for those around a cell in the image:
      around the center of each cell in it, a round gaussian-like distribution heatmap
      is added (of maximum value 1.0), with the specified constant SD.
    '''
    def __getitem__(self,  i):
        # extract image patch
        patch = self.patches[i]
        top_left_x, top_left_y = patch['top_left_corner']
        img = self.img_handle[top_left_y : top_left_y + self.patch_height,
                               top_left_x : top_left_x + self.patch_width]

        # calculate cell centers (center of the markers)
        cell_centers = [((x1+x2) / 2, (y1+y2) / 2)
                        for (x1, y1), (x2, y2) in patch['cellmarks']]

        # draw gaussians around them
        gaussians = [multivariate_normal(mean=[y,x], cov=self.gaussian_sd)\
                        .pdf(self.patch_coords_grid)
                    for (x,y) in cell_centers]

        # append zero gaussian map (so it's not empty)
        gaussians.append(np.zeros((self.patch_height, self.patch_width)))

        # combine them to a single heatmap using 'maximum', and rescale it
        target = np.maximum.reduce(gaussians)
        max = target.max()
        if max > 0.0: target /= max

        # prepare output
        img = torch.tensor(img).movedim(-1,0)
        target = torch.tensor(target).unsqueeze(0)
        if self.trans != None:
          img = self.trans(img)
          target = self.trans(target)
        return img, target

Let's check it out:

In [ ]:
# create a heatmap dataset (with large standard deviation for visualizations)
heatmap_dataset = HeatmapDataset(img_handle, small_patches, small_patch_shape,
                                 gaussian_sd=150.0)
print(f'Heatmap dataset:\n',
      f'\tPatch resolution: {small_patch_shape[1]}px X {small_patch_shape[0]}px\n',
      f'\tSize: {len(heatmap_dataset)} samples')
Heatmap dataset:
 	Patch resolution: 256px X 256px
 	Size: 11367 samples
In [ ]:
# display some samples from the dataset
sample_idxs = [60,987]

for sample_idx in sample_idxs:
    # extract and organize the sample for displaying
    img, target = heatmap_dataset[sample_idx]
    img = img.moveaxis(0,-1).numpy()
    target = target.squeeze().numpy()
    # display its contents
    fig, axes = plt.subplots(1,2, figsize=(7,7))
    axes[0].imshow(img)
    axes[0].set_title(f'Source Patch {sample_idx}')
    axes[0].axis('off')
    axes[1].imshow(target)
    axes[1].set_title(f'Target Heatmap {sample_idx}')
    axes[1].axis('off')

Generating Heatmaps: the U-Net Architecture¶

During the past decade, the field of Deep Learning was recognized as one of the most powerful and reliable approaches in Data Science, and in Computer Vision in particular. Deep Learning allows for some powerful models for generation of image-like data, thanks to the transposed-convolution operation that can be embedded into deep models of decoder architectures. It is only natural to take this approach when designing an algorithm (or in our case - a model) to generate the required heatmaps.

We explored various options, and figured that a good fit for our problem is the popular U-Net architecture (O Ronneberger et al., 2015), which was proved to be a powerful concept for problems such as image segmentation and denoising. The general idea is to use a simple fully-convolutional encoder-decoder structure, but apply residual connections between matching pairs of an encoder layer and a decoder layer (by concatenating the output of each encoder layer to the input of the corresponding decoder layer over the channel dimension). It's called a U-Net since a block-diagram of the architecture has the shape of a large "U".

We'll use a slightly modified version of the network that was proposed in the original paper, one that uses padding in its convolutions to avoid cropping the output of the encoder layers to match the input shape of the decoder layers (and thus lose important information around the edges of the input patches). We also added dropout to each decoder layer.

Since our task is regressing heatmaps of values between 0.0 and 1.0 - the final layer of our network will have a 1-channel output followed by sigmoid activation (whose range is $[0,1]$).

All in all, this is the final architecture we designed for our purpose:

Let's start by implementing the up-blocks and down-blocks that form the U-Net:

In [ ]:
'''
This class represents a "down" block of a UNet architecture.
'''
class UNetDownBlock(nn.Module):
    '''
    Constructor: creates a new UNet "down" block.
    Input:
      > 'in_channels' - number of input channels.
      > 'out_channels' - number of output channels.
      > 'kernel_size' - size of the convolution kernels used. Default: 3.
      > 'down_factor' - dimension and stride for the down-sampling operation.
        Default: 2.
    The block architecture is:
      [in_channels] -> Conv2d + BatchNorm2d + LeakyReLU ->
      [out_channels] -> Conv2d + BatchNorm2d + LeakyReLU ->
      [out_channels] -> MaxPool2d or skip-connection
    '''
    def __init__(self, in_channels, out_channels, kernel_size=3, down_factor=2):
        # call Module class constructor to initialize the Module
        super().__init__()

        # define forward-unit (through which the activations flow)
        self.forward_unit = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )

        # define down-sampling unit
        self.downsample_unit = nn.MaxPool2d(kernel_size=down_factor,
                                            stride=down_factor)

    '''
    Forwards a batch of samples through the block.
    Input:
      > 'x' - tensor of shape (N,in_channels,h,w) representing the input batch.
    Returns a pair (skip_con, output), where:
      > 'skip_con' is the output of the block, before downsampling. I.e. -
        input to the corresponding skip connection of the UNet.
      > 'output' is the downsampled output of the block.
    '''
    def forward(self, x):
        # feed through forward unit
        skip_con = self.forward_unit(x)
        # downsample
        output = self.downsample_unit(skip_con)
        return skip_con, output
In [ ]:
'''
This class represents an "up" block of a UNet architecture.
'''
class UNetUpBlock(nn.Module):
    '''
    Constructor: creates a new UNet "up" block.
    Input:
      > 'in_channels' - number of input channels.
      > 'skip_con_channels' - number of channels entered via the skip connection.
      > 'out_channels' - number of output channels.
      > 'kernel_size' - size of the convolution kernels used. Default: 3.
      > 'up_factor' - dimension and stride for the up-sampling operation.
        Default: 2.
      > 'dropout_prob' - probability for dropout of neurons at the beginning of
        the block. Default: 0.1.
    The block architecture is:
      [in_channels] -> ConvTranspose2d (upsampling) ->
      [in_channels] Concatenate with skip connection [skip_con_channels] along
        channel dimension ->
      [in_channels + skip_con_channels] -> Conv2d + BatchNorm2d + LeakyReLU ->
      [out_channels] -> Conv2d + BatchNorm2d + LeakyReLU
    '''
    def __init__(self, in_channels, skip_con_channels, out_channels, kernel_size=3,
                 up_factor=2, dropout_prob=0.1):
        # call Module class constructor to initialize the Module
        super().__init__()

        # define up-sampling unit
        self.upsample_unit = nn.ConvTranspose2d(
            in_channels, in_channels, kernel_size=up_factor, stride=up_factor
        )

        # define forward-unit (through which the activations flow)
        self.forward_unit = nn.Sequential(
            #nn.Dropout(dropout_prob, inplace=True),
            nn.Conv2d(in_channels + skip_con_channels, out_channels,
                      kernel_size, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )

    '''
    Forwards a batch of samples through the block.
    Input:
      > 'x' - tensor of shape (N,in_channels1,h,w) representing the output of
        the previous block.
      > 'skip_con' - tensor of (N,in_channels2,h,w) representing the output of the
        skip connection.
      Where in_channels1+in_channels2=in_channels.
    Returns the output of the block.
    '''
    def forward(self, x, skip_con):
        # upsample
        x = self.upsample_unit(x)
        # concatenate with skip connection outputs along channel dimension
        x = torch.cat([x, skip_con], dim=1)
        # feed through forward unit
        output = self.forward_unit(x)
        return output

Now we can use those building blocks to implement a general U-Net architecture. We'll also add a built-in customized final layer (that can be changed for different tasks):

In [ ]:
'''
This class represents a UNet encoder-decoder architecture.
'''
class UNet(nn.Module):
    '''
    Constructor: creates a new UNet model.
    Input:
      > 'encode_blocks' - list of instances of UNetDownBlocks representing the encoding
        blocks of the network.
      > 'bottom_block' - bottom block of the network, having the same input and
        output shapes.
      > 'decode_blocks' - list of instances of UNetUpBlocks representing the decoding
        blocks of the network.
      > 'final_block' - final block of the network, processing the last output
        (here you can apply softmax or sigmoid activations for segmentation).
    '''
    def __init__(self, encode_blocks, bottom_block, decode_blocks, final_block):
        # call Module class constructor to initialize the Module
        super().__init__()

        # set the architecture
        self.encode_blocks = nn.ModuleList(encode_blocks)
        self.bottom_block = bottom_block
        self.decode_blocks = nn.ModuleList(decode_blocks)
        self.final_block = final_block

    '''
    Forwards a batch of samples through the network.
    Input:
      > 'x' - tensor of shape (N,in_channels,h,w) representing the input batch.
      > 'get_hist' - a boolean stating whether you wish to obtain the entire history
        of the network's in-between activations. Default: False.
    Returns the output of the network if 'get_hist' is False, or a list of the network
      activations if 'get_hist' is True.
    '''
    def forward(self, x, get_hist=False):
        # keep track of activation history if necessary
        if get_hist: hist = [x]

        # feed through encode layers and record skip connections
        skip_cons = []
        for encode_block in self.encode_blocks:
            skip_con, x = encode_block(x)
            skip_cons.insert(0, skip_con)
            if get_hist: hist.append(x)

        # feed through bottom block
        x = self.bottom_block(x)
        if get_hist: hist.append(x)

        # feed through decode layers
        for up_block, skip_con in zip(self.decode_blocks, skip_cons):
            x = up_block(x, skip_con)
            if get_hist: hist.append(x)

        # feed through final block
        output = self.final_block(x)
        if get_hist: hist.append(output)
        return hist if get_hist else output

Finally - we can construct our architecture:

In [ ]:
# construct a UNet for our purpose - heatmap regression
heatmap_regressor = UNet(
    encode_blocks =  [UNetDownBlock(in_channels=3, out_channels=16, kernel_size=7),
                      UNetDownBlock(in_channels=16, out_channels=32, kernel_size=7),
                      UNetDownBlock(in_channels=32, out_channels=64, kernel_size=5),
                      UNetDownBlock(in_channels=64, out_channels=128, kernel_size=3)],

    bottom_block =    UNetDownBlock(in_channels=128, out_channels=128, kernel_size=3).forward_unit,

    decode_blocks =  [UNetUpBlock(in_channels=128, skip_con_channels=128, out_channels=64, kernel_size=3),
                      UNetUpBlock(in_channels=64, skip_con_channels=64, out_channels=32, kernel_size=5),
                      UNetUpBlock(in_channels=32, skip_con_channels=32, out_channels=16, kernel_size=7),
                      UNetUpBlock(in_channels=16, skip_con_channels=16, out_channels=16, kernel_size=7)],

    final_block =    nn.Sequential(nn.Conv2d(in_channels=16, out_channels=8, kernel_size=5, padding='same'),
                                   nn.Conv2d(in_channels=8, out_channels=1, kernel_size=1, padding='same'),
                                   nn.Sigmoid())
)
heatmap_regressor = heatmap_regressor.to(device)

Let's calculate the number of learned-parameters that such a model is associated with:

In [ ]:
num_params = sum(param.numel() for param in heatmap_regressor.parameters())
print(f'Heatmap regressor contains: {num_params} parameters')
Heatmap regressor contains: 1265009 parameters

Fitting the Heatmap Regressor¶

Let's start by creating a dataset and a dataloader for our purpose. When experiementing with the data and the model, we figured that using small input patches of size 256px by 256px and heatmap gaussians with SD of 5.0 are a good fit.

In [ ]:
# create dataset & dataloader for heatmap regression
heatmap_dataset = HeatmapDataset(img_handle, small_patches,
                                 patch_shape=small_patch_shape,
                                 gaussian_sd=5.0, trans=trans)
heatmap_dataloader = torch.utils.data.DataLoader(
    heatmap_dataset, batch_size=64, num_workers=2, shuffle=True, pin_memory=True
)

Using this, we can write a function to fit our model:

In [ ]:
'''
Fits an encoder-decoder architecture.

Input:
  > 'model' - the model to fit.
  > 'optimizer' - the optimizer to train with; if the optimizer is attached to
    only some of the model's parameters, only those are trained (suitable for
    transfer learning).
  > 'train_dataloader' - dataloader for the training batches.
  > 'criterion' - the loss measure (callable with 2 parameters).
  > 'scheduler' - a scheduler object that will step every epoch.
  > 'epochs' - the number of training epochs. Default: 5.
  > 'rep_sample' - a representative sample of an input sample whose predictions will
    be recorded, or None if you don't wish to use this functionality. Default: None.

Returns a pair (loss_hist, pred_animation), wehere:
  > 'loss_hist' is a list containing the history of the batch losses during the
    training process.
  > 'rep_pred_hist' is a list containing the history of the predictions of the model
    on the representative sample (on CPU), or an empty list if 'rep_sample' is None.
'''
def train_endec_model(model, optimizer, train_dataloader, criterion,
                      scheduler=None, epochs=5, rep_sample=None):

    # lists to record the outputs
    loss_hist = []
    rep_pred_hist = []

    try:
        # record starting time
        start_time = time.time()

        # prepare representative sample
        if rep_sample is not None: rep_sample = rep_sample.unsqueeze(0).to(device)
        # record initial prediction of representative sample
        if rep_sample is not None:
            with torch.no_grad():
                model.eval()
                rep_pred_hist.append(model(rep_sample).cpu())

        for epoch in range(epochs):

            # Iterate over data batches
            for i, (input, target) in tqdm(enumerate(train_dataloader),
                                           f'epoch {epoch+1} / {epochs}',
                                           total=len(train_dataloader)):

                # move batch to the working device
                input = input.to(device)
                target = target.to(device)

                # forward propagation
                optimizer.zero_grad()
                model.train()
                prediction = model(input)
                loss = criterion(prediction, target)

                # backpropagation
                loss.backward()
                optimizer.step()

                # record batch loss and prediction on representative sample
                loss_hist.append(loss.item())
                if rep_sample is not None:
                    with torch.no_grad():
                        model.eval()
                        rep_pred_hist.append(model(rep_sample).cpu())

            # step the scheduler as an epoch was completed
            if scheduler != None: scheduler.step()

        # calculate and inform the user about the time elapsed
        time_elapsed = time.time() - start_time
        print(f'\nTraining completed in ' +
              f'{time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    # in case the user wished to stop training - just end it
    except KeyboardInterrupt:
        pass

    return loss_hist, rep_pred_hist

Finally, we can use it to fit our model. We'll use the Adam optimizer (simply because it worked best for us in the past), and the MSE loss measure (as this is a regression problem). We will also use exponential learning-rate decay, which improved performance by a noticeable margin:

In [ ]:
# prepare representative sample (that we picked in advance)
rep_sample, rep_target = heatmap_dataset[121]

# attach an optimizer and lr-decay scheduler, and fit the model using Adam and MSELoss
criterion = nn.MSELoss()
optimizer = optim.Adam(heatmap_regressor.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5**0.5)

loss_hist, rep_pred_hist = train_endec_model(heatmap_regressor, optimizer,
                                             heatmap_dataloader,
                                             criterion, epochs=10,
                                             rep_sample=rep_sample,
                                             scheduler=scheduler)

# save model and optimizer
torch.save(heatmap_regressor.state_dict(), 'models/heatmap_regressor')
torch.save(optimizer.state_dict(), 'optimizers/heatmap_regressor')
epoch 1 / 10: 100%|██████████| 178/178 [03:44<00:00,  1.26s/it]
epoch 2 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 3 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 4 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 5 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 6 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 7 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 8 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 9 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]
epoch 10 / 10: 100%|██████████| 178/178 [03:47<00:00,  1.28s/it]

Training completed in 37m 56s

Let's take a look at the progress of the model through time:

In [ ]:
# plot loss history (skip first 100 steps due to very sharp drops)
plt.plot(range(100, len(loss_hist)), loss_hist[100:])
plt.xlabel('Training Step')
plt.ylabel('MSE-Loss')
plt.title('Heatmap Regression Loss History (MSE)')
None
In [ ]:
'''
In this cell we animate the progress of the model on the representative sample
  through time
'''

# organize the data to show
rep_sample = rep_sample.moveaxis(0,-1)
rep_target = rep_target.squeeze()
rep_pred_hist = [rep_pred.squeeze() for rep_pred in rep_pred_hist]

# prepare figure
fig, axes = plt.subplots(1,3,figsize=(12,5))
for ax in axes: ax.axis('off')

# prepare plots
axes[0].imshow(rep_sample)
axes[0].set_title('Input Patch')
pred_axes_img = axes[1].imshow(rep_pred_hist[0], vmin=0, vmax=1);
axes[1].set_title('Predicted Heatmap')
axes[2].imshow(rep_target, vmin=0, vmax=1)
axes[2].set_title('Target Heatmap')

# add colorbar
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(pred_axes_img, cax=cbar_ax)

# prepare animation
def drawframe(n):
    pred_axes_img.set_data(rep_pred_hist[n])
    return [pred_axes_img]
anim = animation.FuncAnimation(fig, drawframe, frames=len(rep_pred_hist),
                               interval=20, blit=True)

# display animation (and only animation)
plt.close()
HTML(anim.to_html5_video())
Out[ ]:
Your browser does not support the video tag.

One perk of using a fully-convolutional architecture, is that it can be used on input images of any resolution - not just the resolution of the training data. Let's check it out on an instance of a dataset of larger patches:

In [ ]:
# build dataset of patcehs of an higher resolution (512,2048)
large_heatmap_dataset = HeatmapDataset(img_handle, large_patches,
                                       patch_shape=large_patch_shape,
                                       trans=trans, gaussian_sd=5.0)

# extract a patch from it (which we picked in advance)
patch_idx = 3
patch, target_heatmap = large_heatmap_dataset[patch_idx]

# evaluate the model on it
with torch.no_grad():
    heatmap_regressor.eval()
    output = heatmap_regressor(patch.unsqueeze(0).to(device))

# prepare data for visualization
pred_heatmap = output.cpu().squeeze()
patch = patch.moveaxis(0,-1)
target_heatmap = target_heatmap.squeeze()

# prepare figure
fig, axes = plt.subplots(3,1, figsize=(15,12))
for ax in axes: ax.axis('off')

# plot
axes[0].imshow(patch)
axes[0].set_title('Input Patch')
axes[1].imshow(pred_heatmap, vmin=0, vmax=1)
axes[1].set_title('Predicted Heatmap')
axes[2].imshow(target_heatmap, vmin=0, vmax=1)
axes[2].set_title('Target Heatmap')
plt.subplots_adjust(hspace=0.1)
None

Those results are not good enough. We can see that the model has a hard time localizing the centers of large CR+ cells - which are very important for the structure of putamen samples. This actually makes sense: since the neuron is large, the model gets confused when trying to mark its center (which can be placed on varoius points). So - it decides to just draw a "faded brush" across the middle of the neuron, hoping that some of it lands on the right spot. This "brush" will later be identified as noise, and will be ignored.

It also makes sense that the model will have a hard time generating precise values between 0.0 and 1.0, mimicking actual gaussians. Maybe we should aim for something more rigid.

Model Analysis¶

Before we continue we'll try to gain a better understanding of the trained heatmap regression model, using a suite of model analysis techniques.

Kernel Analysis¶

For starters, we can take a look at the learned kernels and see if we recognize any interesting patterns that the model may look for. Let's analyze the first convolutional layer, since it directly interacts with the input image:

In [ ]:
# extract and display the kernels of the first convolutional layer in the network
first_conv = heatmap_regressor.encode_blocks[0].forward_unit[0]
plt.figure(figsize=(6,6))
plt.suptitle('First Convolutional Kernels of Trained Heatmap Regressor', fontsize='xx-large')
DSProjectUtils.plot_fmap(first_conv.weight, layout_ratio=(1,1), to_normalize=True)

It doesn't seem like the first layer is after any apparent pattern, and since the rest of the layers interact with this layer's output only - there's no point in trying to understand the meaning of their kernels (which represent combinations of these "random" features). Thus, kernel analysis can't help us in this case.

Feature Maps View¶

Another thing we can try is to visualize the feature maps extracted by the network for a given sample. This can help us understand what kind of information helps the network make its predictions:

In [ ]:
# extract the representative sample we picked in advance
sample, target = heatmap_dataset[145]

# display sample
plt.figure(figsize=(16,4))
plt.imshow(sample.moveaxis(0,-1))
plt.title('Input Patch')
plt.axis('off')

# feed sample through the network and record activation history
with torch.no_grad():
    heatmap_regressor.eval()
    activation_hist = heatmap_regressor(sample.to(device).unsqueeze(0),
                                        get_hist=True)

# display this history on a sequence of plots
for i, activation in enumerate(activation_hist[1:-1]):
    plt.figure(figsize=(16,4))
    if activation.shape[1] > 3: layout_ratio = (1,4)
    else: layout_ratio = None
    DSProjectUtils.plot_fmap(activation[0], layout_ratio, cmap='gray')
    plt.suptitle(f'Block {i+1} Output Features')
    plt.subplots_adjust(top=.92)
    plt.show()

# display output
plt.figure(figsize=(16,4))
plt.imshow(activation_hist[-1].squeeze().cpu())
plt.title('Output Heatmap (Normalized)')
plt.axis('off')
None

It is clear that almost every feature map has (for the most part) some sort of fixed texture, interrupted by "anomalies" around interesting spots - which mostly contain objects that resemble CR+ neurons. Thus we can infer that the network is after such regions, looking to identify CR+ neurons and locate their center.

However, most of these "anomalies" are located around the oddly shaped neuron at the top-left of the input patch, while very few of them relate to the occurrences of small CR+-like stains. We may speculate that the network uses most of its "power" to process large cells (such as this one), since it's harder to identify the point that is considered to be their center (while the center of the small cells can easily be located later, using the information passed along by the U-Net architecture). This phenomenon takes place especially in the feature maps of lower resolutions (blocks 3-6), and we may explain it as the network's method to identify the "bounding ellipse" of larger cells.

The phenomenon I find the most interesting is the difference between the feature maps in the first half of the network and the feature maps in the second half of the network, and the change in the "nature" of the maps around the middle of the network (the "bottom" block of the U-Net). Up until the 5th block all "anomalies" in the feature maps had a natural shape - the shape of the same anomaly in the input patch first (blocks 1-2) and then the shape of an ellipse (blocks 3-4). However, starting from the 5th block, these ellipses "crack" and weird patterns start to appear along them (blocks 5-7). These patterns are later "formed" into "stable cracks" and holes in the middle of the ellipse (blocks 8-9). My intuition is that the network establishes these weird patterns in order to determine the center of the cell, which is later on "marked" by these "stable" patterns. All in all, we can speculate that the first half is mainly responsible for identifying the elliptic span of large cells, while the second half is mainly responsible for locating their exact center.

And as for the small stains, they return around the 7th block using the information passed along from the 2nd block - this is where the power of the U-Net architecture really comes into play (the network can focus on the difficult cells, without worrying that it may forget where the others reside). It's interesting to see that the feature maps that regard them (the "lighter" feature maps in block 9) barely contain information about the center of the larger cell. This may mean that later layers are "divided" into 2 distinct parts - one that processs small cells and another that processes large cells. Thus, it's safe to say that the learned model treats cells differently based on their size.

There are 2 additional small (but interesting) observations we can make:

  • You can notice that earlier feature maps (blocks 1 and 2) still regard the axons of the neuron at the top-left of the patch, while later feature maps do not include them. This may mean that the network has learned to discard the axons in order to make more accurate predictions, based only on the span of the neurons themselves.

  • Also, as we progress in the network, the feature maps become less "noisy" and more "smooth" (the frequencies of the maps decrease). We can assume that some part of the network is responsible for identifying and discarding the noisy background of the scanned samples.

Adversarial Attacks¶

Another method we can use to understand the nature of the network is to apply an adversarial attack on the network - synthesize an input patch to produce a wrong output.

To do this, we'll start with an input patch of our choice - $x$, on which the model will predict some correct output $y=f(x)$. We wish to apply a change to $x$ - turning into a synthesized patch we'll call $x'$ - such that the output of the model on the synthesized patch is a $y'$ of our choice (which is different from the correct $y=f(x)$). However, in order for the attack to be successful, $x'$ should be very similar to $x$ - thus discovering a vulnerability of the network as it predicts a wrong $y'$ on an input that looks like $x$ (whose favourable output is $y$).

The task at hand is to find that $x'$: a patch that's similar to $x$ on the one hand, that on the other hand, the model evaluates to $y'$. To do this, we can solve the following optimization problem:

$\DeclareMathOperator*{\minimize}{minimize}$ $\displaystyle{\minimize_{x'} \ \alpha \cdot L_{in}(x',x) + L_{out}(f(x'), y')}$

Where $L_{in}$ is a distance measure between input patches, $L_{out}$ is a distance measure between model outputs, and $\alpha$ is a tradeoff hyper-parameter (representing how important it is for us that the synthesized patch will be similar to the original). We'll use the Adam gradient-based optimizer to solve this optimization problem.

All in all, we can write a function to perform the operation:

In [ ]:
'''
Optimizes an adversarial attack on a model.
Input:
  > 'model' - the model to attack.
  > 'x' - the original input to use (PyTorch tensor with batch dimension).
  > 'y_target' - the target output we wish that the model will produce for the
    synthesized input (PyTorch tensor with batch dimension).
  > 'epochs' - the number of optimization epochs to run.
  > 'alpha' - a tradeoff hyper-parameter lambda (the larger it is, the more the
    synthesized input will be similiar to the original).
  > 'lr' - learning rate for the optimization process. Default: 1e-3.
  > 'desc' - we use the tqdm library to visualize the progress, this will be its
    description of it. Set to None if you don't want to visualize the progress.
    Default: None.
  > 'l_in' - loss measure between model inputs. Default: MSE loss.
  > 'l_out' - loss measure between model outputs. Default: MSE loss.
Returns a pair (x_tag_hist, y_pred_hist), where:
  > 'x_tag_hist' is a list containing the history of the synthesized inputs during
    the optimization process.
  > 'y_pred_hist' is a list containing the history of the output of the model on
    the synthesized input during the optimization process.
'''
def adversarial_attack(model, x, y_target, epochs, alpha, lr=1e-3, desc=None,
                       l_in=nn.functional.mse_loss, l_out=nn.functional.mse_loss):
    # copy and move everything to device
    x = x.clone().detach().to(device)
    y_target = y_target.clone().detach().to(device)

    # copy the input into a tensor that will be optimized
    x_tag = x.clone()
    x_tag.requires_grad = True

    # set an optimizer
    optimizer = optim.Adam([x_tag], lr=lr)

    # set returned historoes
    x_tag_hist = [x]
    y_pred_hist = []

    # run optimization steps
    model.eval()
    for _ in range(epochs) if desc is None else tqdm(range(epochs),desc):
        # clip x_tag to valid range and feed it through
        optimizer.zero_grad()
        clipped_x_tag = torch.clip(x_tag, 0.0, 1.0)
        y_pred = model(clipped_x_tag)
        # calculate and optimize loss
        loss = alpha * l_in(x_tag, x) + l_out(y_pred, y_target)
        loss.backward()
        optimizer.step()
        # record x_tag and y_pred
        x_tag_hist.append(clipped_x_tag.clone().detach())
        y_pred_hist.append(y_pred.clone().detach())

    # record final prediction
    with torch.no_grad():
        y_pred = model(x_tag)
        y_pred_hist.append(y_pred.clone().detach())

    return x_tag_hist, y_pred_hist

Now we can use this method to attack our model.

The first attack will optimize a patch that contains CR+ cells towards an empty heatmap (filled with zeros). In a way, the attack will try to efficiently "hide" the cells from the model, so we can understand what the model does not pay attention to.

In [ ]:
# extract the representative sample we picked in advance
sample, target = heatmap_dataset[97]

# run adversarial attack to produce empty heatmap (filled with zeros)
x = sample.unsqueeze(0)
y_target = torch.zeros(1,1,*x.shape[-2:])
x_tag_hist, y_pred_hist = adversarial_attack(
    heatmap_regressor, x, y_target,
    epochs=200, alpha=1e-2, lr=1e-2, desc='Optimizing attack on heatmap regressor'
)

# display results
anim = DSProjectUtils.animate_attack(x_tag_hist, y_pred_hist, y_target)
plt.close()
HTML(anim.to_html5_video())
Optimizing attack on heatmap regressor: 100%|██████████| 200/200 [00:07<00:00, 26.31it/s]
Out[ ]:
Your browser does not support the video tag.

The attack changed the color of the cells to green (and added some more bits of green around them) - thus we can fool the model by playing with the colors. Note that the brown color is composed of lots of red, and some green. Hence, the attack actually removed the red (and not necessarily "added green"). This may mean that the model determines the presence of a CR+ cell by the amount of red in the blob's complexion.

The second attack will optimize a patch that contains no CR+ cells towards a full heatmap (filled with ones). The attack will try to fool the model to think that there are CR+ neurons in the patch, and we can use it to try and understand what the model looks for.

In [ ]:
# extract the representative sample we picked in advance
sample, target = heatmap_dataset[123]

# run adversarial attack to produce full heatmap (filled with ones)
x = sample.unsqueeze(0)
y_target = torch.ones(1,1,*x.shape[-2:])
x_tag_hist, y_pred_hist = adversarial_attack(
    heatmap_regressor, x, y_target,
    epochs=200, alpha=2.5e0, lr=5e-3, desc='Optimizing attack on heatmap regressor'
)

# display results
anim = DSProjectUtils.animate_attack(x_tag_hist, y_pred_hist, y_target)
plt.close()
HTML(anim.to_html5_video())
Optimizing attack on heatmap regressor: 100%|██████████| 200/200 [00:07<00:00, 26.01it/s]
Out[ ]:
Your browser does not support the video tag.

This result is quite interesting: the attack strengthened the brown complexion of regions that were already brown by adding red to them, so their tone became similar to that of CR+ neurons in the training data. Thus, we may speculate that the model distinguishes between irrelevant stains and actual CR+ neurons by the amount of red in their tone - the redder it is, the more they're likely to be classified as CR+ neurons. This hypothesis explains the previous attack, that removed the red from the brown complexion of CR+ neurons to hide them from the model.

Gradient Maps¶

Last but not least, we'll analyze gradient maps of the error. The idea is to backpropagate the error of the model on a given sample and look at its gradient with respect to the input image. The gradient at a given pixel will represent the effect that the pixel has on the loss - how sensitive the loss is to changes in that pixel. The further a pixel's gradient is from 0.0, the more "problematic" it is. Note that in order to minimize the loss, we'd want to move the pixels' intensities in the opposite direction of their gradients.

We can use the auto-grad system of PyTorch to evaluate these gradients easily:

In [ ]:
# extract the representative sample we picked in advance
sample, target = heatmap_dataset[145]
sample = sample.to(device)
sample.requires_grad = True
target = target.to(device)

# feed sample through heatmap regressor and backpropagate loss
prediction = heatmap_regressor(sample.unsqueeze(0))
loss = criterion(prediction, target.unsqueeze(0))
loss.backward()

# obtain gradient maps (per color channel)
r_gmap, g_gmap, b_gmap = sample.grad

# prepare figure
fig, axes = plt.subplots(2,3, figsize=(15,10))
for ax in axes.flatten(): ax.axis('off')
divnorm = colors.TwoSlopeNorm(vcenter=0.0)

# plot relevant info

axes[0][0].imshow(sample.detach().moveaxis(0,-1).cpu())
axes[0][0].set_title('Input Patch')
axes[0][1].imshow(prediction.detach().squeeze().cpu(), vmin=0.0, vmax=1.0)
axes[0][1].set_title('Predicted Heatmap')
axes[0][2].imshow(target.squeeze().cpu(), vmin=0.0, vmax=1.0)
axes[0][2].set_title('Target Heatmap')

m = axes[1][0].imshow(r_gmap.detach().cpu(), cmap='seismic', norm=divnorm)
plt.colorbar(m, ax=axes[1][0], fraction=0.046, pad=0.04)
axes[1][0].set_title('Red Channel Gradient Map')

m = axes[1][1].imshow(g_gmap.detach().cpu(), cmap='seismic', norm=divnorm)
plt.colorbar(m, ax=axes[1][1], fraction=0.046, pad=0.04)
axes[1][1].set_title('Green Channel Gradient Map')

m = axes[1][2].imshow(b_gmap.detach().cpu(), cmap='seismic', norm=divnorm)
plt.colorbar(m, ax=axes[1][2], fraction=0.046, pad=0.04)
axes[1][2].set_title('Blue Channel Gradient Map')

None

As one might expect - the "issues" rise around the predicted cells - as these areas are strongly associated with the corresponding "hills" on the predicted heatmap. Apart from that, the rest of the image has almost no effect, it's completely ignored. That means the model is very successful in identifying irrelevant regions.

The bottom pair of predicted cells are false positives, while the one at the top is a true positive that the model is not very confident about. As a result, their gradients have a reversed nature - the confidence in the formers should drop, while the confidence in the latter should rise. This reversed nature is visible in the gradient maps. For instance - the blue channel gradient map is high in the center of the top cell and low around it, while it's low in the center of the bottom cell and high around it.

The effect of the colors is very apparent in these maps:

  • The red gradient is mostly negative around the top cell (the true positive), and not so much around the others (the false positives). Thus, having stronger shades of red helps the model identify a cell in a region, as we've already observed.

  • Green and blue have an opposite effect - their gradients are high at the center of the true positive and low around it, and vice-versa for the false positives (this effect is stronger for the blue channel). Thus, a positively classified cell should have less of these shades at its center, and more of them around it.

All of this makes sense, as the CR+ cells are very brown - a color which is mostly composed of red, with bits of green and no blue. Adding green and blue makes the blobs' complexion less brown (and more yellowish, like the background). Moreover, blue regions in the data usually don't contain the required cells, so it makes sense that they should not be present where a CR+ neuron should be.

Second Attempt: Segmentation¶

The Idea and the Dataset¶

In order to solve both problems of the heatmap regressor, we need to take a different approach. A natural alternative is segmentation: instead of "drawing" gaussians around the center of the CR+ neurons, we can just identify all the pixels that are actually part of them. This is in fact a classification problem - classifying which pixels belong to a cell, and which are not.

This can be done by generating binary segmentation maps - binary maps of the same resolution, representing which pixels belong to a cell (marked with 1) and which are not (marked with 0):

Figure 14: Generating a segmentation map from a patch containing cellmarks.

This method solves both problems we had with the heatmap regressor:

  • The heatmap regressor had a hard time localizing large neurons, since it couldn't decide what exactly is considered to be its center. However, a segmentation model should not have such difficulty, since it can just mark the entire cell (and not only some precise point around its middle).

  • Segmentation maps are rigid targets - containing only zeros and ones, unlike the heatmap gaussians - containing entries between 0.0 and 1.0. We figured that such rigid targets may be easier to follow and learn, and it might be unnecessarily complex to try and mimic an exact gaussian.

There is one big technical issue with this approach: how would we get the data to fit a model on? We'll need marked segmentation maps, containing extensive information about the precise span of the CR+ neurons. If you recall, as we explored the data at hand in the "Preparing the Data" section, we noticed that the XML annotations file contains a few annotated cell perimeters. We may want to use these to generate binary segmentation maps. However - there were very few of them and they were all clustered around the bottom-left of the WSI, which makes them a bad source of data that is not representative enough.

Nevertheless, you may also recall that we actually already generated such maps using image processing techniques: in the "Processing the Cells" section, this was one of the steps to generate the diameter marker of a localized cell. We can use this functionality on the localized cells in the training dataset to generate target segmentation maps.

All in all, let's write a PyTorch dataset that will generate those segmentation maps for extracted image tiles:

In [ ]:
'''
This class represents a dataset for detecting multiple brain cells in IHC stained
  whole-slide-images (WSI), based on generated segmenation maps of localized cells.
'''
class SegmentationDataset(Dataset):
    '''
    Constructs a new dataset.
    Input:
      > 'img_handle' - an opened pytiff handle object for the TIFF slide-image.
      > 'patches' - list of patch dictionaries, as returned from 'extract_tiles'.
        It is assumed that all patches have the same shape.
      > 'patch_shape' - shape of the patches in 'patches'.
      > 'trans' - a callable that will be called on the image patches before
        returning them. Default: None (no transformation will be applied).
    '''
    def __init__(self, img_handle, patches, patch_shape, trans=None):
        # store relevant info
        patch_height, patch_width = patch_shape
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.img_handle = img_handle
        self.patches = patches
        self.trans = trans

    '''
    Returns the number of elements in this dataset.
    '''
    def __len__(self):
        # the number of elements is the number of patches
        return len(self.patches)

    '''
    Gets an element from this dataset.
    Input - index of the element.
    Returns a pair (img, target), where:
    > 'img' is a PyTorch tensor of shape (3,h,w) containing a patch of the tiff image,
      which may or may not contain cells.
    > 'target' is a PyTorch tensor of shape (1,h,w) with binary entries (0.0 or 1.0).
      The '1' entries correspond to pixels in the input images that lie on a target
      cell, and the '0' entries correspond to background pixels.
    '''
    def __getitem__(self,  i):
        # extract image patch
        patch = self.patches[i]
        top_left_x, top_left_y = patch['top_left_corner']
        img = self.img_handle[top_left_y : top_left_y + self.patch_height,
                                top_left_x : top_left_x + self.patch_width]

        # calculate cell centers (center of the markers)
        cell_centers = [(math.floor((y1+y2)/2), math.floor((x1+x2)/2))
                        for (x1, y1), (x2, y2) in patch['cellmarks']]

        # to speed things up we decided to discard cells whose center is outside the
        # patch (even if some of it is inside; we tried both ways and the difference
        # was very minor)
        cell_centers = filter(lambda center : \
                                  0 <= center[0] < self.patch_height and \
                                  0 <= center[1] < self.patch_width,
                              cell_centers)

        # obatin segmentation maps
        percell_seg_maps = [get_cell_segmentation(img, center)
                            for center in cell_centers] + [np.zeros(img.shape[:-1])]
        target = np.maximum.reduce(percell_seg_maps)

        # prepare output
        img = torch.tensor(img).movedim(-1,0)
        target = torch.tensor(target, dtype=torch.float32).unsqueeze(0)
        if self.trans != None:
          img = self.trans(img)
          target = self.trans(target)
        return img, target

Let's check it out:

In [ ]:
# create a segmentation dataset
segment_dataset = SegmentationDataset(img_handle, small_patches,
                                      small_patch_shape, trans=trans)
print(f'Segmentation dataset:\n',
      f'\tPatch resolution: {small_patch_shape[1]}px X {small_patch_shape[0]}px\n',
      f'\tSize: {len(segment_dataset)} samples')
Segmentation dataset:
 	Patch resolution: 256px X 256px
 	Size: 11367 samples
In [ ]:
# display some samples from the dataset
sample_idxs = [83,987,121]

for sample_idx in sample_idxs:
    # extract and organize the sample for displaying
    img, target = segment_dataset[sample_idx]
    img = img.moveaxis(0,-1).numpy()
    target = target.squeeze().numpy()
    # display its contents
    fig, axes = plt.subplots(1,2, figsize=(7,7))
    axes[0].imshow(img)
    axes[0].set_title(f'Source Patch {sample_idx}')
    axes[0].axis('off')
    axes[1].imshow(target)
    axes[1].set_title(f'Target Segmentation Map {sample_idx}')
    axes[1].axis('off')

Generating Segmentation Maps: the U-Net Strikes Again¶

The task at hand is to classify the pixels in the input images into 2 classes: pixels that belong to CR+ neurons, and pixels that do not. However, binary classification as it is is not a differentiable process, and a common practice in the machine learning community is to use the famous logistic regression instead - which is all about regressing the probability that some input belongs to one class or another.

So in practice, we'll need to generate the probability that each pixel belongs to a neuron. There are many pixels, and we might as well do it in parallel - by generating a probability map of the same resolution as the input image, whose entries represent the required probabilities for the corresponding pixels.

All in all - we face a pretty similar task to heatmap-regression: in both cases, we generate a 1-channel map of values between 0.0 and 1.0 with the same resolution as the input image. Since we already implemented and tested the U-Net encoder-decoder architecture, we might as well use it for this task as well. In fact, the original paper that proposed this popular architecture (O Ronneberger et al., 2015) used it for a segmentation task (of biomedical images, just like us!)

As in classic logistic regression models, we'll use the sigmoid activation function to produce the final probability map. All in all, this is the architecture we'll use:

In [ ]:
# construct a UNet for our purpose - segmentation (probability-map regression)
segmentator = UNet(
    encode_blocks =  [UNetDownBlock(in_channels=3, out_channels=16, kernel_size=7),
                      UNetDownBlock(in_channels=16, out_channels=32, kernel_size=7),
                      UNetDownBlock(in_channels=32, out_channels=64, kernel_size=5),
                      UNetDownBlock(in_channels=64, out_channels=128, kernel_size=3)],

    bottom_block =    UNetDownBlock(in_channels=128, out_channels=128, kernel_size=3).forward_unit,

    decode_blocks =  [UNetUpBlock(in_channels=128, skip_con_channels=128, out_channels=64, kernel_size=3),
                      UNetUpBlock(in_channels=64, skip_con_channels=64, out_channels=32, kernel_size=5),
                      UNetUpBlock(in_channels=32, skip_con_channels=32, out_channels=16, kernel_size=7),
                      UNetUpBlock(in_channels=16, skip_con_channels=16, out_channels=16, kernel_size=7)],

    final_block =    nn.Sequential(nn.Conv2d(in_channels=16, out_channels=8, kernel_size=5, padding='same'),
                                   nn.Conv2d(in_channels=8, out_channels=1, kernel_size=1, padding='same'),
                                   nn.Sigmoid())
)
segmentator = segmentator.to(device)

Let's calculate the number of parameters associated with such a model:

In [ ]:
num_params = sum(param.numel() for param in segmentator.parameters())
print(f'Segmentation model contains: {num_params} parameters')
Segmentation model contains: 1265009 parameters

This is the same number of parameters that the heatmap regressor contains. This makes sense, as both models are based on the same underlying architecture.

Fitting the Segmentation Model¶

Let's start by creating a dataset and a dataloader for our purpose. When experiementing with the data and the model, we figured that using small input patches of size 256px by 256px is a good fit:

In [ ]:
# create dataset & dataloader for segmentation
segment_dataset = SegmentationDataset(img_handle, small_patches,
                                      patch_shape=small_patch_shape,
                                      trans=trans)
segment_dataloader = torch.utils.data.DataLoader(
    segment_dataset, batch_size=64, num_workers=2, shuffle=True, pin_memory=True
)

As this task is almost identical to the heatmap regression task we tackled previously, we can use the same training procedure that we already defined (which is a general training procedure for encoder-decoder architectures).

One crucial difference is that we'll use the cross-entropy error measure instead of MSE, as this is a classification task. Apart from that, we'll use a similar training "environment" - an Adam optimizer and an exponential learning-rate decay:

In [ ]:
# prepare representative sample (that we picked in advance)
rep_sample, rep_target = segment_dataset[121]

# attach an optimizer and lr-decay scheduler, and fit the model using Adam and BCELoss
criterion = nn.BCELoss()
optimizer = optim.Adam(segmentator.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5**0.5)

loss_hist, rep_pred_hist = train_endec_model(segmentator, optimizer,
                                             segment_dataloader,
                                             criterion, epochs=5,
                                             rep_sample=rep_sample,
                                             scheduler=scheduler)

# save model and optimizer
torch.save(segmentator.state_dict(), 'models/segmentator')
torch.save(optimizer.state_dict(), 'optimizers/segmentator')
epoch 1 / 5: 100%|██████████| 178/178 [11:23<00:00,  3.84s/it]
epoch 2 / 5: 100%|██████████| 178/178 [11:20<00:00,  3.83s/it]
epoch 3 / 5: 100%|██████████| 178/178 [11:26<00:00,  3.86s/it]
epoch 4 / 5: 100%|██████████| 178/178 [11:25<00:00,  3.85s/it]
epoch 5 / 5: 100%|██████████| 178/178 [11:29<00:00,  3.87s/it]
Training completed in 57m 6s

Let's take a look at the progress of the model through time:

In [ ]:
# plot loss history (skip first 50 steps due to sharp drops)
plt.plot(range(50, len(loss_hist)), loss_hist[50:])
plt.xlabel('Training Step')
plt.ylabel('BCE-Loss')
plt.title('Segmentation Loss History (BCE)')
None
In [ ]:
'''
In this cell we animate the progress of the model on the representative sample
  through time
'''

# organize the data to show
rep_sample = rep_sample.moveaxis(0,-1)
rep_target = rep_target.squeeze()
rep_pred_hist = [rep_pred.squeeze() for rep_pred in rep_pred_hist]

# prepare figure
fig, axes = plt.subplots(1,3,figsize=(12,5))
for ax in axes: ax.axis('off')

# prepare plots
axes[0].imshow(rep_sample)
axes[0].set_title('Input Patch')
pred_axes_img = axes[1].imshow(rep_pred_hist[0], vmin=0, vmax=1);
axes[1].set_title('Predicted Probability Map')
axes[2].imshow(rep_target, vmin=0, vmax=1)
axes[2].set_title('Target Segmentation Map')

# add colorbar
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(pred_axes_img, cax=cbar_ax)

# prepare animation
def drawframe(n):
    pred_axes_img.set_data(rep_pred_hist[n])
    return [pred_axes_img]
anim = animation.FuncAnimation(fig, drawframe, frames=len(rep_pred_hist),
                               interval=20, blit=True)

# display animation (and only animation)
plt.close()
HTML(anim.to_html5_video())
Out[ ]:
Your browser does not support the video tag.

It seems like the model's output on the representative sample oscillates a lot starting from the 2nd second of the animation - which could mean that the model has started to converge, but has a hard time getting to the local optimum. This means that we may want to use a lower learning-rate, or alternatively strengthen the learning-rate decay. Time is too short to make this adjustment, so we'll leave the model as it is.

Since this is a fully-convolutional model as well, we can use it on input images of any resolution. Let's check it out on an instance of a dataset of larger patches:

In [ ]:
# build dataset of patcehs of an higher resolution (512,2048)
large_segment_dataset = SegmentationDataset(img_handle, large_patches,
                                            patch_shape=large_patch_shape,
                                            trans=trans)

# extract a patch from it (which we picked in advance)
patch_idx = 3
patch, target_segment = large_segment_dataset[patch_idx]

# evaluate the model on it
with torch.no_grad():
    segmentator.eval()
    output = segmentator(patch.unsqueeze(0).to(device))

# prepare data for visualization
pred_segment = output.cpu().squeeze()
patch = patch.moveaxis(0,-1)
target_segment = target_segment.squeeze()

# prepare figure
fig, axes = plt.subplots(3,1, figsize=(15,12))
for ax in axes: ax.axis('off')

# plot
axes[0].imshow(patch)
axes[0].set_title('Input Patch')
axes[1].imshow(pred_segment, vmin=0, vmax=1)
axes[1].set_title('Predicted Segmentation Map')
axes[2].imshow(target_segment, vmin=0, vmax=1)
axes[2].set_title('Target Segmentation Map')
plt.subplots_adjust(hspace=0.1)
None

This looks way better! There is no issue with the large neuron at the bottom. There are still a few false positives, which may be CR+ neurons that the research team genuinely missed. Unfortunately, time is too short to tackle these as well. However, we mention a possible solution in the final section of the project.

Model Analysis¶

Before we continue we'll try to gain a better understanding of the trained segmentation model, using the same suite of model analysis techniques we used to analyze the heatmap regressor.

Kernel Analysis¶

For starters, we can take a look at the learned kernels and see if we recognize any interesting patterns that the model may look for. Let's analyze the first convolutional layer since it directly interacts with the input image:

In [ ]:
# extract and display the kernels of the first convolutional layer in the network
first_conv = segmentator.encode_blocks[0].forward_unit[0]
plt.figure(figsize=(6,6))
plt.suptitle('First Convolutional Kernels of Trained Segmentation Model', fontsize='xx-large')
DSProjectUtils.plot_fmap(first_conv.weight, layout_ratio=(1,1), to_normalize=True)

As with the heatmap regression model, it doesn't seem like the first layer of the segmentation model is after any apparent pattern, and since the rest of the layers interact with this layer's output only - there's no point in trying to understand the meaning of their kernels (which represent combinations of these "random" features). Thus, kernel analysis can't help us in this case as well.

Feature Maps View¶

Viewing the feature maps extracted by the network helped us gain a thorough understanding of the heatmap regressor, and we might as well use it here. This can indicate what kind of information helps the segmentation model make its predictions.

In [ ]:
# extract the representative sample we picked in advance
sample, target = segment_dataset[145]

# display sample
plt.figure(figsize=(16,4))
plt.imshow(sample.moveaxis(0,-1))
plt.title('Input Patch')
plt.axis('off')

# feed sample through the network and record activation history
with torch.no_grad():
    segmentator.eval()
    activation_hist = segmentator(sample.to(device).unsqueeze(0),
                                        get_hist=True)

# display this history on a sequence of plots
for i, activation in enumerate(activation_hist[1:-1]):
    plt.figure(figsize=(16,4))
    if activation.shape[1] > 3: layout_ratio = (1,4)
    else: layout_ratio = None
    DSProjectUtils.plot_fmap(activation[0], layout_ratio, cmap='gray')
    plt.suptitle(f'Block {i+1} Output Features')
    plt.subplots_adjust(top=.92)
    plt.show()

# display output
plt.figure(figsize=(16,4))
plt.imshow(activation_hist[-1].squeeze().cpu())
plt.title('Output Probaility Map (Normalized)')
plt.axis('off')
None

The cell segmentation task is all about identifying the span of the cells and discarding irrelevant parts (such as axons or noise of any kind). However, in addition to identifying the span of the cells, the heatmap regression model also needs to localize the point that is considered to be its center and draw a gaussian-like stain around it. Thus cell segmentation may be considered to be an easier task, that requires less complex analysis of the input patches. And indeed - when looking at the feature maps that the segmentator extracted, it's hard to find interesting operations that the model performed (unlike with the heatmap regressor).

Another phenomenon that supports this hypothesis is that quite a few feature maps that the segmentator extracted were pretty empty, with hard-to-see anomalies around not-interesting regions - mostly along the edges of the map (which may be due to convolutional padding that confused the network). This can mean that the network didn't really need all its power to solve the segmentation task, and that a smaller network may also be able to do the trick.

However, we can still make a few observations about the computational process of the model. For starters, just like the heatmap regressor, almost every feature map extracted by the segmentator has (for the most part) some sort of fixed texture that's interrupted by "anomalies" around interesting spots - which mostly contain objects that resemble CR+ neurons. Thus we can infer that the network is after such regions, looking to identify the span of CR+ neurons.

Also, as we progress in the network, the feature maps become less "noisy" and more "smooth" (the frequencies of the maps decrease). We can assume that some part of the network is responsible for identifying and discarding the noisy background of the scanned samples.

Since some of the axons may still belong to the span of the cell, they are not completely ignored. However, they disappear around the 4th block and come back later around the 8th block (using the information passed along by the U-Net architecture). The same happens with the small cell-like stains, which disappear around the 3rd block and come back around the 8th block. This may be due to resolution constraints: it's harder to express thin axons and small stains in low-resolution maps. Also, one may assume that the network temporarily ignores them so it can analyse the elliptic span of the larger cell (which is harder to process).

Adversarial Attacks¶

We can also use adversarial attacks to find out the segmentation model's weaknesses and infer the way it operates, just like with the heatmap regression model. We'll use the adversarial_attack routine we implemented previously, with BCE loss to compare the model's output to the required map (since the output is a probability map).

The first attack will optimize a patch that contains CR+ cells towards an empty segmentation map (filled with zeros). In a way, the attack will try to efficiently "hide" the cells from the model, so we can understand what the model does not pay attention to.

In [ ]:
# extract the representative sample we picked in advance
sample, target = segment_dataset[97]

# run adversarial attack to produce segmentation map (filled with zeros)
x = sample.unsqueeze(0)
y_target = torch.zeros(1,1,*x.shape[-2:])
x_tag_hist, y_pred_hist = adversarial_attack(
    segmentator, x, y_target,
    epochs=150, alpha=1e-1, lr=5e-3, desc='Optimizing attack on segmentator',
    l_out=nn.BCELoss()
)

# display results
anim = DSProjectUtils.animate_attack(x_tag_hist, y_pred_hist, y_target)
plt.close()
HTML(anim.to_html5_video())
Optimizing attack on segmentator: 100%|██████████| 150/150 [00:05<00:00, 25.64it/s]
Out[ ]:
Your browser does not support the video tag.

The attack discolored the strong brown shades of the actual cells in the image by taking away the red in them, making them more gray. This way, the attack managed to hide the presence of the cell from the model. Thus, we can assume that the model strongly associates cells with red-brown tones, just like the heatmap regressor.

For some reason, the area around them also became a bit more green. It's hard to understand why this helps hide the cells from the model though.

The second attack will optimize a patch that contains no CR+ cells towards a full heatmap (filled with ones). The attack will try to fool the model to think that there are CR+ neurons in the patch, and we can use it to try and understand what the model looks for.

In [ ]:
# extract the representative sample we picked in advance
sample, target = segment_dataset[123]

# run adversarial attack to produce full heatmap (filled with ones)
x = sample.unsqueeze(0)
y_target = torch.ones(1,1,*x.shape[-2:])
x_tag_hist, y_pred_hist = adversarial_attack(
    segmentator, x, y_target,
    epochs=300, alpha=7.5e2, lr=1e-3, desc='Optimizing attack on segmentator',
    l_out=nn.BCELoss()
)

# display results
anim = DSProjectUtils.animate_attack(x_tag_hist, y_pred_hist, y_target)
plt.close()
HTML(anim.to_html5_video())
Optimizing attack on segmentator: 100%|██████████| 300/300 [00:11<00:00, 26.01it/s]
Out[ ]:
Your browser does not support the video tag.

The attack messed with the color of some regions that were already pretty brown by adding lots of green and some red to them, making them more green-brown-ish. Thus, we may speculate that the model strongly associates green shades with CR+ neurons, as well as bits of red. This makes sense, as these cells usually have a strong brown complexion - a tone that's composed of lots of red with some green.

Gradient Maps¶

Last but not least, we'll analyse gradient maps of the error of the model, with respect to an input sample - just like with the heatmap regressor. This way we can determine the effect each pixel has on the loss - how sensitive the loss is to changes in that pixel. Note that in order to minimize the loss, we'd want to move the pixels' intensities in the opposite direction of their gradients.

We'll use the auto-grad system of PyTorch to evaluate these gradients easily:

In [ ]:
# extract the representative sample we picked in advance
sample, target = segment_dataset[145]
sample = sample.to(device)
sample.requires_grad = True
target = target.to(device)

# feed sample through segmentator and backpropagate loss
prediction = segmentator(sample.unsqueeze(0))
loss = criterion(prediction, target.unsqueeze(0))
loss.backward()

# obtain gradient maps (per color channel)
r_gmap, g_gmap, b_gmap = sample.grad

# prepare figure
fig, axes = plt.subplots(2,3, figsize=(15,10))
for ax in axes.flatten(): ax.axis('off')
divnorm = colors.TwoSlopeNorm(vcenter=0.0)

# plot relevant info

axes[0][0].imshow(sample.detach().moveaxis(0,-1).cpu())
axes[0][0].set_title('Input Patch')
axes[0][1].imshow(prediction.detach().squeeze().cpu(), vmin=0.0, vmax=1.0)
axes[0][1].set_title('Predicted Probability Map')
axes[0][2].imshow(target.squeeze().cpu(), vmin=0.0, vmax=1.0)
axes[0][2].set_title('Target Segmentation')

m = axes[1][0].imshow(r_gmap.detach().cpu(), cmap='seismic', norm=divnorm)
plt.colorbar(m, ax=axes[1][0], fraction=0.046, pad=0.04)
axes[1][0].set_title('Red Channel Gradient Map')

m = axes[1][1].imshow(g_gmap.detach().cpu(), cmap='seismic', norm=divnorm)
plt.colorbar(m, ax=axes[1][1], fraction=0.046, pad=0.04)
axes[1][1].set_title('Green Channel Gradient Map')

m = axes[1][2].imshow(b_gmap.detach().cpu(), cmap='seismic', norm=divnorm)
plt.colorbar(m, ax=axes[1][2], fraction=0.046, pad=0.04)
axes[1][2].set_title('Blue Channel Gradient Map')

None

Just like with the heatmap regressor, the "issues" rise around the predicted cells - as these areas are strongly associated with the corresponding "hills" on the predicted heatmap. Apart from that, the rest of the image has almost no effect, it's completely ignored. That means that the model is pretty successful in identifying irrelevant regions.

It seems like the blue channel has the most coherent affect - the gradient is positive around pixels that should be positively-classified, and negative around pixels that should be negatively-classified. Thus, the model strongly associates blue shades with regions that contain no CR+ cells. This makes sense, as blue regions in the training data are usually empty.

The red channel gradient map and the green channel gradient map have a similar nature - the gradients around the problematic areas are a mix of positive bits and negative bits. The similarity in their nature is reasonable, as we've already seen that the model associates both of these shades with the presence of CR+ neurons (due to the brown tone they compose). Their "messy" nature is harder to explain though. It may be due to some unknown "reasoning" of the model.

Extracting Cell Coordinates¶

The target goal of this section was to localize CR+ cells in input patches. However, so far we only managed to generate maps that indicate their location - but do not provide their exact coordinates. We'll use 2 methods to extract those - one suitable for heatmaps, and another suited for segmentation maps.

Extracting from Heatmaps: Non-Maximum Suppression¶

To extract cell coordinates from generated heatmaps, we'll use non-maximum suppression - a well-known technique in the field of Computer Vision. The idea is to pick pixels whose value in the heatmap is high, but "suppress" pixels whose value is not the highest in some neighbourhood around them - thus only choosing the "peaks" of the "gaussians" (which should be the centers of the localized cells). Thankfully, PyTorch provides a quick implementation of this mechanism, which we'll use.

However, sometimes noisy predictions may cause small "hills" on the maps, which we mostly don't wish to classify as cells. To suppress these as well we'll use a minimum threshold to filter small "noisy" entries.

Let's combine these ideas into a single function that will extract cell coordinates from a given map:

In [ ]:
'''
Uses non-maximum supression to extract coordinates of cells from a generated heatmap
  of localized cells.
Input:
  > 'map' - the generated map (2-dimensional torch tensor).
  > 'threshold' - threshold for filtering small "noisy" entries in the map.
    Default: 0.1.
  > 'w_margin' & 'h_margin' - horizontal & vertical margins betweeen two cells
    (in pixels), respectively. Default: 5 & 5. If two cells are closer, only one
    of them will be extracted (according to 'iou_threshold').
  > 'iou_threshold' - IOU threshold for non-maximum-suppresion (see torchvision.ops.nms).
Returns a tensor of shape (num_of_cells,2), whose rows are the coordinate pairs
  of the cells (row number and then column number).
'''
def extract_coords_nms(map, threshold=0.25, w_margin=5, h_margin=5, iou_threshold=0.0):
    # filter map entries below threshold
    pos_coords = torch.argwhere(map >= threshold)
    pos_x_coords = pos_coords[:,1]
    pos_y_coords = pos_coords[:,0]

    # extend them to "pseudo bounding-boxes" (for torchvision NMS implementation)
    pos_boxes = torch.column_stack([
        pos_x_coords,
        pos_y_coords,
        pos_x_coords + torch.full_like(pos_x_coords, w_margin),
        pos_y_coords + torch.full_like(pos_y_coords, h_margin)
    ]).float()

    # collect scores
    scores = map[pos_y_coords, pos_x_coords]

    # non-maximum-suppression
    coords_idxs = torchvision.ops.nms(pos_boxes, scores, iou_threshold=iou_threshold)
    return pos_coords[coords_idxs]

Let's check it out on a sample of the "large-patches" dataset we constructed previously, using the trained heatmap regressor:

In [ ]:
# extract a patch from it (which we picked in advance)
patch_idx = 5
patch, target_heatmap = large_heatmap_dataset[patch_idx]
target_cell_coords = extract_coords_nms(target_heatmap.squeeze())
vis_patch = patch.moveaxis(0,-1) # for visualization

# evaluate the model on it
with torch.no_grad():
    heatmap_regressor.eval()
    output = heatmap_regressor(patch.unsqueeze(0).to(device))

# prepare output for visualization
pred_heatmap = output.cpu().squeeze()
filtered_pred_heatmap = torch.where(pred_heatmap >= 0.25, pred_heatmap,
                                   torch.zeros_like(pred_heatmap))

# prepare figure
fig, axes = plt.subplots(5,1, figsize=(12,17))
for ax in axes: ax.axis('off')

# plot original patch and targets
axes[0].set_title('Input Patch')
axes[0].imshow(vis_patch)
axes[1].set_title('Target Cell Locations')
DSProjectUtils.plot_pinned_patch(axes[1], vis_patch, target_cell_coords)

# plot prediction info
axes[2].set_title('Predicted Heatmap')
axes[2].imshow(pred_heatmap, vmin=0, vmax=1)
axes[3].set_title('Filtered Predicted Heatmap')
axes[3].imshow(filtered_pred_heatmap, vmin=0, vmax=1)
axes[4].set_title('Predicted Cell Locations')
DSProjectUtils.plot_pinned_patch(
    axes[4], vis_patch,
    extract_coords_nms(pred_heatmap)
)

Extracting from Segmentation Maps: Connected Component Analysis¶

Segmentation maps tell us which pixels are classified as pixels that belong to CR+ neurons. Following the reasonable assumption that each neuron lies on a continuous region of the image, we can obtain the exact span of the neurons by analysing the connected components of positively-classified pixels. Each of these will represent the span of a single neuron (assuming that none of them overlap), whose center can later be identified as the component's center of mass or centroid.

We'll use a probability threshold to determine which pixels are positively-classified as pixels that lie on CR+ neurons, and will use the SciPy library to extract the required components.

All in all, let's write a single function that will extract cell coordinates from a given probability segmentation map:

In [ ]:
'''
Uses connected component analysis to extract coordinates of cells from a generated
  probability segmentation map of localized cells.
Input:
  > 'map' - the generated map (2-dimensional torch tensor).
  > 'threshold' - probability threshold for positively classified pixels. Default:
    0.5.
Returns a tensor of shape (num_of_cells,2), whose rows are the coordinate pairs
  of the cells (row number and then column number).
'''
def extract_coords_cca(map, threshold=0.5):
    # locate connected components
    labels, num_features = ndimage.label(map >= threshold)
    # calculate centeroids
    centroids = ndimage.measurements.center_of_mass(np.ones_like(labels), labels,
                                                    range(1,num_features+1))
    return torch.tensor(centroids).floor().to(torch.int32)

Let's test it using the segmentation map generated for the same sample, using the trained segmentation model:

In [ ]:
# evaluate the model on the sample we previously extracted
with torch.no_grad():
    segmentator.eval()
    output = segmentator(patch.unsqueeze(0).to(device))

# prepare output for visualization
pred_segment_probs = output.cpu().squeeze()
pred_segments = (pred_segment_probs >= 0.5)

# prepare figure
fig, axes = plt.subplots(5,1, figsize=(12,17))
for ax in axes: ax.axis('off')

# plot original patch and targets
axes[0].set_title('Input Patch')
axes[0].imshow(vis_patch)
axes[1].set_title('Target Cell Locations')
DSProjectUtils.plot_pinned_patch(axes[1], vis_patch, target_cell_coords)

# plot prediction info
axes[2].set_title('Predicted Segmentation Probability Map')
axes[2].imshow(pred_segment_probs, vmin=0, vmax=1)
axes[3].set_title('Predicted Segmentation Map')
axes[3].imshow(pred_segments)
axes[4].set_title('Predicted Cell Locations')
DSProjectUtils.plot_pinned_patch(
    axes[4], vis_patch,
    extract_coords_cca(pred_segment_probs)
)

End to End System¶

Now, after we've constructed the required functionalities to automatically locate CR+ neurons and extract their diameter, we can connect to pieces together into an end-to-end system that will perform the job that was previously done manually by brain-science experts.

Let's recall the task at hand. We have whole-slide scans of putamen samples in the form of TIFF images, alongside XML files marking the relevant regions in them. We need to scan these regions, identify CR+ cells in them, mark their diameter and record their length.

The Pipeline¶

In this part we'll use the functionalities we implemented throughout the project to construct the entire pipeline that will complete the required task.

Since the scans take a few gigabytes of memory and thus can't be loaded into main memory all at once, we use tiling: we divide the image into small patches in the relevant regions and process them separately. Let's connect the tiling functionality we implemented in the "Processing the cells" section to a PyTorch dataset, so we could later take advantage of the PyTorch CUDA compatibility and process these patches simultaneously on GPU:

In [ ]:
'''
This class represents a dataset of tiles of a whole-slide image, extracted by the
  "extract_tiles" functionality.
'''
class TilesDataset(Dataset):
    '''
    Constructs a new dataset.
    Input:
      > 'img_handle' - an opened pytiff handle object for the TIFF slide-image.
      > 'tiles' - list of patch dictionaries, as returned from 'extract_tiles'.
        It is assumed that all patches have the same shape.
      > 'patch_shape' - shape of the tiles in 'tiles'.
    '''
    def __init__(self, img_handle, tiles, patch_shape):
        # store relevant info
        self.img_handle = img_handle
        self.patches = tiles
        self.patch_height, self.patch_width = patch_shape

    '''
    Returns the number of elements in this dataset.
    '''
    def __len__(self):
        # the number of elements is the number of tiles
        return len(self.patches)

    '''
    Gets an element from this dataset.
    Input - index of the element.
    Returns a pair: (img, coords), where:
      > 'img' is a PyTorch float tensor of shape (3,*patch_shape) representing the
        extracted tile.
      > 'coords' - a PyTorch tensor of shape (2,) containing the coordinates of the
        top left corner of the patch in the whole-slide image (row number and then
        column number).
    '''
    def __getitem__(self,  i):
        # extract image patch
        patch = self.patches[i]
        top_left_x, top_left_y = patch['top_left_corner']
        img = self.img_handle[top_left_y : top_left_y + self.patch_height,
                                top_left_x : top_left_x + self.patch_width]

        # prepare output
        img = torch.tensor(img).movedim(-1,0)
        img = convert_image_dtype(img, torch.float)
        top_left_corner = torch.tensor([top_left_y, top_left_x])
        return img, top_left_corner

Next, we need to identify the CR+ neurons in the patches using one of the 2 cell localization methods we previously implemented: heatmap regression + non-maximum suppression, or cell segmentation + connected component analysis. We'll use a dataloader to feed the patches through the networks in batches, thus taking adventage of parallel computation. Let's write a routine to extract cell coordinates in the entire slide image, using an instance of the tiles dataset we just implemented and an extraction method of the user's choice:

In [ ]:
'''
Extracts coordinates of localized CR+ neurons in a while-slide scan of a putamen
  sample.
Input:
  > 'tiles_dataloader' - a PyTorch dataloader for an instance of a TilesDataset,
    which is associated with the relevant patches in the slide scan.
  > 'model' - cell localization model (a heatmap regression network or a segmentation
    network).
  > 'extraction_method' - the method used to extract cell coordinates from the
    localization maps produced by 'model'. Should be one of the following two strings:
    - 'nms' for non-maximum suppression (suited for a heatmap regression model).
    - 'cca' for connected component analysis (suited for a segmentation model).
    Default hyper-parameters are used.
Returns: a tensor of shape (num_of_cells,2), whose rows are the coordinate pairs
  of the cells (row number and then column number).
'''
def detect_cells(tiles_dataloader, model, extraction_method):
    # validate extraction method
    if extraction_method is 'nms':
        extract_coords = extract_coords_nms
    elif extraction_method is 'cca':
        extract_coords = extract_coords_cca
    else:
        raise ValueError('\'extraction_method\' should be either \'nms\' or \'cca\'.')

    # initialize list of cell coordinate tensors for the patches
    coords_tensors_list = []
    model.eval()

    # iterate through the tile patches
    for patches, coord_pairs in tqdm(tiles_dataloader, desc='Detecting CR+ neurons'):
        # generate localization maps
        with torch.no_grad():
            maps = model(patches.to(device))
        # extract cell coordinates from each map
        for map, top_left_coords in zip(maps, coord_pairs):
            coords_tensors_list.append(extract_coords(map.squeeze().cpu()) +
                                       top_left_coords)

    return torch.cat(coords_tensors_list)

After we detected the required neurons in the slide scan, we can use the method we developed in the "Processing the Cells" section to identify their diameter. To do this, we need to process windows of the image around them. We experimented with the data at hand and figured that using a rectangular window of resolution 110pxs by 110pxs around each cell is best. Let's write a function that will do exactly that, and produce the edge coordinates of the detected cells' diameters:

In [ ]:
'''
Processes deteceted CR+ cells in a slide image, and identifies their diameters.
Input:
  > 'cell_coords_tensor' -  a tensor of shape (num_of_cells,2), whose rows are the
    coordinate pairs of the cells' centers (row number and then column number).
  > 'img_handle' - an opened pytiff handle object for the TIFF slide-image.
  > 'window_shape' - shape of window around cells that's used to identify their
    diameter. Note: it's assumed that the every window around every cell is completely
    inside the TIFF image. Default: (110,110)
Returns a numpy array of shape (num_of_cells,2,2), whose (n,i)'th entry is the coordinate
  vector for the i'th edge of the n'th cell diameter (row number and then column
  number). These are detected using the improved percell method.
'''
def identify_diameters(cell_coords_tensor, img_handle, window_shape=(110,110)):
    # extract window shape
    window_h, window_w = window_shape

    # initialize list of diameter pairs
    diameter_list = []

    # process each cell individually
    for (r,c) in tqdm(cell_coords_tensor, 'Identifying cell diameters'):

        # process the required window around the cell
        window = img_handle[r-window_h//2 : r+window_h//2,
                            c-window_w//2 : c+window_h//2]
        diameter_points = get_diameter_points(window, (window_h//2, window_w//2),
                                              segmentator=new_get_cell_segmentation)
        diameter_list.append(np.array(diameter_points) +
                             np.array([r-window_h//2, c-window_w//2]))

    return np.stack(diameter_list)

Finally, we can connect everything into an end-to-end system that can run on any scanned sample, provided the paths for the files associated with it:

In [ ]:
'''
The whole pipeline of our project: locates CR+ neurons in a given sample of the
  putamen region of the brain (scanned using whole-slide scanning), and extracts
  the neurons' diameters.

Input:
  > 'tiff_path' - path for the TIFF image containing the scanned sample.
  > 'xml_path' - path for the associated XML annotations file that contains information
    about the relevant regions.
  > 'tile_shape' - shape of tiles to extract from the image (which will be fed into
    the localization system).
  > 'batch_size' - size of batches of tiles fed into the localization system.
  > 'localization_model' - cell localization model to use (a heatmap regressor or a
    segmentator).
  > 'extraction_method' - string, the name of the coordinate extraction method used ('nms'
    for non-maximum suppression or 'cca' for connected component analysis).

Additional inputs:
  > 'num_workers' - number of workers for the DataLoader object used to load tiles
    before feeding them into the localization system. Default: 2.
  > 'bound_min_len' - hyperparameter sent to 'process_wsi_annots'. Same default (1000).
  > 'intersection_part' - hyperparameter sent to 'extract_tiles'. Same default (0.5).
  > 'window_shape' - hyperparameter sent to 'identify_diameters'. Same default ((110,110)).

Return a dictionary containing all relevant information regarding the sample. The
  dictionary has the following entries:
  > 'img_handle' - opened image handle.
  > 'img_areas' - the list returned from process_wsi_annots'
  > 'microns_per_pixel' - scale of the image.
  > 'cell_centers' - a numpy array of shape (num_of_cells,2), whose rows are the
    coordinates of the centers of the cells.
  > 'cell_diameters' - a numpy array of shape (num_of_cells,2,2), whose (n,i)'th
    entry is the coordinate pair of the i'th edge of the diameter of the n'th
    detected cell.
'''
def process_sample(tiff_path, xml_path, tile_shape, batch_size,
                   localization_model, extraction_method, num_workers=2,
                   bound_min_len=1000, intersection_part=0.5, window_shape=(110,110)):
    # open an image handle for the image
    img_handle = Tiff(tiff_path)

    # process the xml annotations file and extract tiles
    img_areas, microns_per_pixel = process_wsi_annots(xml_path,
                                                      bound_min_len=bound_min_len)
    tiles = extract_tiles(img_areas, img_handle.shape[:2], tile_shape,
                          intersection_part=intersection_part)

    # create tiles dataset and attach a dataloader to it
    tiles_dataset = TilesDataset(img_handle, tiles, tile_shape)
    tiles_dataloader = DataLoader(tiles_dataset, batch_size=batch_size,
                                  num_workers=num_workers)

    # detect cells and identify their diameters
    cell_coords_tensor = detect_cells(tiles_dataloader, localization_model,
                                      extraction_method)
    diameters_tensor = identify_diameters(cell_coords_tensor, img_handle,
                                          window_shape=window_shape)

    # organize information dictionary
    result = {'img_handle' : img_handle,
              'img_areas' : img_areas,
              'microns_per_pixel' : microns_per_pixel,
              'cell_centers' : np.array(cell_coords_tensor),
              'cell_diameters' : diameters_tensor}
    return result

Evaluation Using Seen Data¶

The entire system was developed and trained according to the single putamen sample we've seen so far - sample ID12718. In this part we'll evaluate our system using this sample. Although it is considered "seen data", it can still serve as a good way to measure the power of our system, for 2 main reasons:

  • First, the per-cell diameter detection method was developed according to very few of the CR+ neurons in the sample - ones that we manually observed and analyzed. Thus, most of the image (that contained more than 3,000 marked cells) did not affect the developed method at all.

  • In contrast, the localization networks were trained using data from the entire image. However, we could still see that the output they produced for it was not very ideal. Thus, overfitting is less probable.

Let's run the system on that sample:

In [ ]:
# record starting time
start_time = time.time()

# process sample
info_dictionary = process_sample(
     tiff_path='samples/ID12718.tiff',
     xml_path='samples/ID12718.xml',
     tile_shape=(512,2048), batch_size=8, num_workers=2,
     localization_model=heatmap_regressor, extraction_method='nms'
)

# extract relevant info from the returned dictionary
img_handle = info_dictionary['img_handle']
img_areas = info_dictionary['img_areas']
microns_per_pixel = info_dictionary['microns_per_pixel']
cell_centers = info_dictionary['cell_centers']
cell_diameters = info_dictionary['cell_diameters']

# inform the user about the time elapsed
time_elapsed = time.time() - start_time
s = time_elapsed % 60
m = int(time_elapsed // 60)
print(f'\n{len(cell_centers)} CR+ neurons detected and processed in {m}:{s:05.2f}',
      'minutes.')
Detecting CR+ neurons: 100%|██████████| 89/89 [01:39<00:00,  1.11s/it]
Identifying cell diameters: 100%|██████████| 3722/3722 [04:20<00:00, 14.31it/s]
3722 CR+ neurons detected and processed in 6:02.08 minutes.

Let's take a look at a few of predictions:

In [ ]:
# oganize figure
plt.figure(figsize=(15,15))
plt.suptitle('Predictions on Putamen Sample ID12718', fontsize='xx-large', y=0.94)

# plot predictions
DSProjectUtils.plot_predictions(img_handle, cell_centers, cell_diameters,
                                to_plot=range(40,56,1))

We see that the system did a good job with the detected cells:

  1. The detected cells are actual cells.
  2. The cells detected are different from one another.
  3. The diameter predictions fit.

We now look at the distribution of the detected diameter lengths.

In [ ]:
# calculate lengths from endpoints
cells_lengths_pix =np.array([((p0[0]-p1[0])**2+(p0[1]-p1[1])**2)**(1/2.) for (p0,p1) in cell_diameters])
cell_lengths_um=cells_lengths_pix*info_dictionary['microns_per_pixel']
In [ ]:
# plot lengths
sns.histplot(cell_lengths_um)
plt.title('Predicted Diameter Lengths (ums)')
plt.xlabel('Length (ums)')
plt.show()

It looks like a healthy distribution: the majority of lengths is below 20ums. To further examine the distribution, we compare it to the ground-truth distribution of the lengths.

Recall that the evaluation was done on the same image we examined in the EDA section. The data about the cells in the image is stored in the frame variable.

In [ ]:
# plot prediction and GT length distributions
plt.title('Length Distributions')
sns.histplot(cell_lengths_um, color='purple', label='Prediction')
sns.histplot(frame['LengthMicrons'], color='yellow', label='Ground Truth')
plt.legend()
plt.show()

We see that the predicted distribution overall captures the GT distribution. Therefore, for uses where the lengths are taken in to account collectively, these are good predictions.

Note that the count of the detected cells could be different from the number of cells that actually exist in the image, as shown below.

In [ ]:
print('Number of GT cells:', len(frame['Length']))
print('Number of predicted cells:', len(cell_centers))
Number of GT cells: 3182
Number of predicted cells: 3722

As we see, our model detected approximately 500 unmarked (possible) cells!

Evaluation using Unseen Data¶

Here we carry out the same procedure as before, but on a new WSI (ID12719) that hasn't been used throughout the project. This WSI has been used and referenced only in this section (in fact we got our hands on this WSI only when we got to this point in the project).

Since the evaluation of the percell method was adequately performed in the previous section, this evaluation is mainly for the better evaluation of the localization model.

We first glimpse into the new sample, and then carry out the evaluation.

A Glimpse at the Sample¶

Figure 15: Putamen sample ID12719, divided into its relevant regions.

Because we aren't familiar with the new sample, first observe the sample to see its general characteristics.

We load the XML file to look at the relevant distributions of the cells. The loading process was done already at the EDA, which was encapsulated into a dedicated function in the utils script.

In [ ]:
test_frame=DSProjectUtils.load_xml('samples/ID12719.xml')

We glance into the xml:

In [ ]:
test_frame.head()
Out[ ]:
Id Type Zoom Selected ImageLocation ImageFocus Length Area LengthMicrons AreaMicrons Text NegativeROA InputRegionId Analyze DisplayId X0 Y0 X1 Y1
0 297823 4 1 0 0 18.4 0.0 9.2 0.0 0 0 0 2 1395 24036 1379 24045
1 297824 4 1 0 0 18.0 0.0 9.0 0.0 0 0 0 3 1433 23884 1434 23902
2 297825 4 1 0 0 21.3 0.0 10.7 0.0 0 0 0 4 1692 24411 1676 24425
3 297826 4 1 0 0 17.5 0.0 8.8 0.0 0 0 0 5 1734 24392 1738 24409
4 297827 4 1 0 0 19.6 0.0 9.9 0.0 0 0 0 6 1761 24833 1756 24852

We look at the sample to see whether it is similar to the one we are acquainted with. The two measures are the cell-lengths in the WSIs and the number of cells present.

We now count the number of cells in the two samples:

In [ ]:
print(f'Train has {len(frame)} cells.')
print(f'Test has {len(test_frame)} cells.')
Train has 3182 cells.
Test has 3521 cells.

The test sample contains $\sim$400 more samples compared to the train sample. As we will see in the next section, the cells themselves wear the familiar shape and style as the ones in the training sample.

With look at the cell-length distributions:

In [ ]:
# show the two length distributions at the same plot
plt.title('Cell Length Distribution (ums)')
sns.histplot(frame['LengthMicrons'], color='blue', label='Train WSI')
sns.histplot(test_frame['LengthMicrons'], color='red', label='Test WSI')
plt.legend()
plt.show()

We see that the distributions match in shape: a high spike at around 10ums, with decreasing quantities to the right.

The test sample has more cells above the 10ums length range than the train sample.

Evaluation¶

We now evaluate the end-to-end system on the test sample. The code and flow in this section is essentialy the same as the one for the Train sample.

We now carry out the evaluation on the new WSI:

In [ ]:
# record starting time
start_time = time.time()

# process sample
test_info_dictionary = process_sample(
     tiff_path='samples/ID12719.tiff',
     xml_path='samples/ID12719.xml',
     tile_shape=(512,2048), batch_size=8, num_workers=2,
     localization_model=heatmap_regressor, extraction_method='nms'
)

# extract relevant info from the returned dictionary
test_img_handle = test_info_dictionary['img_handle']
test_img_areas = test_info_dictionary['img_areas']
test_microns_per_pixel = test_info_dictionary['microns_per_pixel']
test_cell_centers = test_info_dictionary['cell_centers']
test_cell_diameters = test_info_dictionary['cell_diameters']

# inform the user about the time elapsed
time_elapsed = time.time() - start_time
s = time_elapsed % 60
m = int(time_elapsed // 60)
print(f'\n{len(cell_centers)} CR+ neurons detected and processed in {m}:{s:05.2f}',
      'minutes.')
Detecting CR+ neurons: 100%|██████████| 104/104 [02:19<00:00,  1.35s/it]
Identifying cell diameters: 100%|██████████| 4289/4289 [05:13<00:00, 13.67it/s]
3722 CR+ neurons detected and processed in 7:38.73 minutes.

We start with the numerical examinations.

In [ ]:
print('Number of GT cells:', len(test_frame['Length']))
print('Number of predicted cells:', len(test_cell_centers))
Number of GT cells: 3521
Number of predicted cells: 4289

On the Test sample we predicted much more cells than those marked in it: 800 cells!

Lets see the length distributions. We start by calculating the lengths from the predicted diameter endpoints:

In [ ]:
# calculate lengths from endpoints
test_cells_lengths_pix =np.array([((p0[0]-p1[0])**2+(p0[1]-p1[1])**2)**(1/2.) for (p0,p1) in test_cell_diameters])
test_cell_lengths_um=test_cells_lengths_pix*test_microns_per_pixel

And progress with plotting the length distributions:

In [ ]:
# plot prediction and GT length distributions
plt.title('Length Distributions')
sns.histplot(test_cell_lengths_um, color='purple', label='Prediction')
sns.histplot(test_frame['LengthMicrons'], color='yellow', label='Ground Truth', binrange=(0,60))
plt.legend()
plt.show()

As seen above, the ground truth length distribution resembles the one predicted. The predicted distribution amplifies the spike of the ground truth distribution at the expense of lower quantities at the right hand of the spike.

And now the part we've all been waiting for: looking at the predictions!

In [ ]:
# oganize figure
plt.figure(figsize=(15,15))
plt.suptitle('Predictions on Putamen Sample ID12719', fontsize='xx-large', y=0.94)

# plot predictions
DSProjectUtils.plot_predictions(test_img_handle, test_cell_centers, test_cell_diameters,
                                to_plot=range(0,20,1))

The results are fantastic, with the same notes of the previous section applicable here: separate cells and fine diameter detection.

Because looking at nice results is always enjoyable, we finish this section with some more predictions. We already plotted cells 0-19, and now plot other cells:

In [ ]:
# oganize figure
plt.figure(figsize=(15,15))
plt.suptitle('Predictions on Putamen Sample ID12719', fontsize='xx-large', y=0.94)

# plot predictions
DSProjectUtils.plot_predictions(test_img_handle, test_cell_centers, test_cell_diameters,
                                to_plot=range(20,56,1))

Conclusion¶

We presented our system to Doctor Paz Kelmer, and he was very impressed with its performance. He even mentioned that the automatically-extracted cellmarks tend to be even more accurate than the ones that were manually placed by his team. He was glad that he got to work with us, and he plans to show the developed system to his team and may even make extensive use of it in the future. All in all, we are super happy that we got to work on such an important and interesting project. We hope that our algorithms may contribute to the field of neuroscience in the future, and if not - at least we got to work on a very interesting computer vision task :)

All things considered, we may conclude that the project was a success!

Suggestions for Future Work¶

Although the system's performance was quite adequate, there are still a few research directions we may explore and experiment with in order to make it even better. This is the place to list a few of the ideas we have in mind:

  1. Crafting a better deep segmentation model, through an improved dataset.

    We can manually create a small dataset of cell segmentations, or alternatively use the cell-perimeter markings provided in the XML annotations file (which we ignored in our algorithms due to how few there were). This could extend upon the cell segmentations already available. Given such an accurate dataset, a transfer learning segmentation model could be trained to create an improved one. Transfer learning is used to compensate for the small size of the hand-crafted dataset. This model is expected to outperform the current segmentation model, as its labels will be marked superbly and not by a noisy image processing segmentator.

  1. Detecting diameters based on the predicted segmentation maps.

    Instead of using image processing in order to create the cell segmentation maps, we may use the output of the segmentation model itself - which was fit for such "primitive" segmentations, but generalized to more accurate maps. The rest of the percell method will stay the same (recognizing the two furthest apart points on the contour of the segmentation map).

  1. Adding an in-between cell classifier.

    As we've seen, both localization methods produced a few false positives. These may merely be CR+ neurons that the research team genuinely missed, but either way - we could cope with them by adding a classifier that will filter such false positives wisely. The classifier can be fit using the regions that the localization systems detected a cell in them, and the ground truth knowledge about whether a cell is present or not.

Good Luck ^_^¶